@@ -65,7 +65,7 @@ void CheckNoopTensor(const Tensor &t, const std::string &name) {
6565 }
6666}
6767
68- void CheckScaleTensorShape (const Tensor &t, bool check_scale_inv_alignment ) {
68+ void CheckScaleTensorShape (const Tensor &t) {
6969 NVTE_CHECK (t.scaling_mode != NVTE_INVALID_SCALING, " Invalid scaling mode!" );
7070 if (is_tensor_scaling (t.scaling_mode )) {
7171 // per-tensor scaling
@@ -80,7 +80,6 @@ void CheckScaleTensorShape(const Tensor &t, bool check_scale_inv_alignment) {
8080 }
8181 } else {
8282 if (t.scaling_mode == NVTE_MXFP8_1D_SCALING) {
83- if (!check_scale_inv_alignment) return ;
8483 // Need (4, 128) alignment even for e8 scaling factor
8584 auto block_alignment = std::vector<size_t >{128ul / typeToSize (t.scale_inv .dtype ),
8685 4ul / typeToSize (t.scale_inv .dtype )};
@@ -111,7 +110,7 @@ void CheckScaleTensorShape(const Tensor &t, bool check_scale_inv_alignment) {
111110 }
112111}
113112
114- void CheckInputTensor (const Tensor &t, const std::string &name, bool check_scale_inv_alignment ) {
113+ void CheckInputTensor (const Tensor &t, const std::string &name) {
115114 const DType type = t.dtype ();
116115 if (is_fp8_dtype (type)) {
117116 // FP8 input needs to have scale_inv
@@ -143,11 +142,10 @@ void CheckInputTensor(const Tensor &t, const std::string &name, bool check_scale
143142 }
144143 NVTE_CHECK (t.has_data () || t.has_columnwise_data (), " Input " , name, " is not allocated!" );
145144
146- CheckScaleTensorShape (t, check_scale_inv_alignment );
145+ CheckScaleTensorShape (t);
147146}
148147
149- void CheckOutputTensor (const Tensor &t, const std::string &name, bool allow_empty,
150- bool check_scale_inv_alignment) {
148+ void CheckOutputTensor (const Tensor &t, const std::string &name, bool allow_empty) {
151149 const DType type = t.dtype ();
152150 if (is_fp8_dtype (type)) {
153151 // FP8 output needs to have scale, scale_inv and (if delayed scaling) amax
@@ -189,7 +187,7 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt
189187 NVTE_CHECK (t.has_data () || t.has_columnwise_data (), " Output " , name, " is not allocated!" );
190188 }
191189
192- CheckScaleTensorShape (t, check_scale_inv_alignment );
190+ CheckScaleTensorShape (t);
193191}
194192
195193} // namespace transformer_engine
0 commit comments