@@ -35,7 +35,7 @@ const char* NORMALIZE_PLUGIN_NAME{"Normalize_TRT"};
3535PluginFieldCollection NormalizePluginCreator::mFC {};
3636std::vector<PluginField> NormalizePluginCreator::mPluginAttributes ;
3737
38- Normalize::Normalize (const Weights* weights, int nbWeights, bool acrossSpatial, bool channelShared, float eps)
38+ Normalize::Normalize (Weights const * weights, int nbWeights, bool acrossSpatial, bool channelShared, float eps)
3939 : acrossSpatial(acrossSpatial)
4040 , channelShared(channelShared)
4141 , eps(eps)
@@ -44,11 +44,13 @@ Normalize::Normalize(const Weights* weights, int nbWeights, bool acrossSpatial,
4444 PLUGIN_VALIDATE (nbWeights == 1 );
4545 PLUGIN_VALIDATE (weights[0 ].count >= 1 );
4646 mWeights = copyToDevice (weights[0 ].values , weights[0 ].count );
47+ mScalarScale = static_cast <float const *>(weights[0 ].values )[0 ];
4748}
4849
4950Normalize::Normalize (
50- const Weights* weights, int nbWeights, bool acrossSpatial, bool channelShared, float eps, int C, int H, int W)
51- : acrossSpatial(acrossSpatial)
51+ Weights const * weights, int nbWeights, float scalarScale, bool acrossSpatial, bool channelShared, float eps, int C, int H, int W)
52+ : mScalarScale(scalarScale)
53+ , acrossSpatial(acrossSpatial)
5254 , channelShared(channelShared)
5355 , eps(eps)
5456 , C(C)
@@ -74,6 +76,7 @@ Normalize::Normalize(const void* buffer, size_t length)
7476
7577 mNbWeights = read<int >(d);
7678 int count = read<int >(d);
79+ std::memcpy (&mScalarScale , d, sizeof (float ));
7780 mWeights = deserializeToDevice (d, count);
7881 PLUGIN_VALIDATE (d == a + length);
7982}
@@ -111,8 +114,19 @@ int Normalize::enqueue(
111114{
112115 const void * inputData = inputs[0 ];
113116 void * outputData = outputs[0 ];
114- pluginStatus_t status = normalizeInference (stream, mCublas , acrossSpatial, channelShared, batchSize, C, H, W, eps,
115- static_cast <const float *>(mWeights .values ), inputData, outputData, workspace);
117+
118+ pluginStatus_t status;
119+
120+ if (acrossSpatial && channelShared) // Since cublasPointerMode_t is CUBLAS_POINTER_MODE_HOST, scale should be on the host
121+ {
122+ status = normalizeInference (stream, mCublas , acrossSpatial, channelShared, batchSize, C, H, W, eps,
123+ &mScalarScale , inputData, outputData, workspace);
124+ }
125+ else // No risk of device pointers being passed to cublas as alpha or beta
126+ {
127+ status = normalizeInference (stream, mCublas , acrossSpatial, channelShared, batchSize, C, H, W, eps,
128+ static_cast <float const *>(mWeights .values ), inputData, outputData, workspace);
129+ }
116130
117131 return status;
118132}
@@ -254,7 +268,7 @@ IPluginV2Ext* Normalize::clone() const noexcept
254268 try
255269 {
256270 // Create a new instance
257- IPluginV2Ext* plugin = new Normalize (&mWeights , mNbWeights , acrossSpatial, channelShared, eps, C, H, W);
271+ IPluginV2Ext* plugin = new Normalize (&mWeights , mNbWeights , mScalarScale , acrossSpatial, channelShared, eps, C, H, W);
258272
259273 // Set the namespace
260274 plugin->setPluginNamespace (mPluginNamespace .c_str ());
0 commit comments