Skip to content

Commit 80061d6

Browse files
committed
address review comments and add tests
1 parent 45e7dba commit 80061d6

File tree

3 files changed

+55
-13
lines changed

3 files changed

+55
-13
lines changed

mlir/include/mlir/Dialect/Arith/IR/ArithOps.td

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1227,9 +1227,9 @@ def Arith_ScalingExtFOp
12271227
OptionalAttr<Arith_FastMathAttr>:$fastmath)>,
12281228
Results<(outs FloatLike:$out)> {
12291229
let summary =
1230-
"Upcasts quantized floats using provided scales values following OCP MXFP Spec";
1230+
"Upcasts input floats using provided scales values following OCP MXFP Spec";
12311231
let description = [{
1232-
This operation upcasts quantized floating-point values using provided scale
1232+
This operation upcasts input floating-point values using provided scale
12331233
values. It expects both scales and the input operand to be of the same shape,
12341234
making the operation elementwise. Scales are usually calculated per block
12351235
following the OCP MXFP spec as described in https://arxiv.org/abs/2310.10537.
@@ -1253,7 +1253,6 @@ def Arith_ScalingExtFOp
12531253
resultTy = get_type(result)
12541254
scaleTy = get_type(scale)
12551255
inputTy = get_type(input)
1256-
assert(scaleTy.shape() == inputTy.shape() == resultTy.shape())
12571256
scale.exponent = arith.truncf(scale) : scaleTy to f8E8M0
12581257
scale.extf = arith.extf(scale.exponent) : f8E8M0 to resultTy
12591258
input.extf = arith.extf(input) : inputTy to resultTy
@@ -1350,7 +1349,7 @@ def Arith_ScalingTruncFOp
13501349
let summary =
13511350
"Downcasts input floating point values using provided scales values following OCP MXFP Spec";
13521351
let description = [{
1353-
This operation quantizes input using the provided scale values. It expects
1352+
This operation downcasts input using the provided scale values. It expects
13541353
both scales and the input operand to be of the same shape and, therefore,
13551354
makes the operation elementwise. Scales are usually calculated per block
13561355
following the OCP MXFP spec as described in https://arxiv.org/abs/2310.10537.
@@ -1378,7 +1377,6 @@ def Arith_ScalingTruncFOp
13781377
scaleTy = get_type(scale)
13791378
inputTy = get_type(input)
13801379
resultTy = get_type(result)
1381-
assert(scaleTy.shape() == inputTy.shape() == resultTy.shape())
13821380
scale.exponent = arith.truncf(scale) : scaleTy to f8E8M0
13831381
scale.extf = arith.extf(scale.exponent) : f8E8M0 to inputTy
13841382
result = arith.divf(input, scale.extf)

mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -359,7 +359,7 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
359359
result = b.create<arith::TruncFOp>(resultTy, result, nullptr,
360360
op.getFastmathAttr());
361361
} else if (resultETy.getIntOrFloatBitWidth() > 32) {
362-
result = b.create<arith::ExtFOp>(resultTy, result);
362+
result = b.create<arith::ExtFOp>(resultTy, result, op.getFastmathAttr());
363363
}
364364
rewriter.replaceOp(op, result);
365365
return success();
@@ -417,11 +417,13 @@ struct ScalingExtFOpConverter : public OpRewritePattern<arith::ScalingExtFOp> {
417417
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
418418
Value inputOperand = op.getIn();
419419
Value scaleOperand = op.getScale();
420+
Type scaleTy = scaleOperand.getType();
420421
Type scaleETy = getElementTypeOrSelf(scaleOperand);
421422
// allow implicit exponent extraction from 16/32 bits floats
422423
if (scaleETy.getIntOrFloatBitWidth() >= 16) {
423424
scaleETy = b.getF8E8M0Type();
424-
scaleOperand = b.create<arith::TruncFOp>(scaleETy, scaleOperand, nullptr,
425+
scaleTy = cloneToShapedType(scaleTy, scaleETy);
426+
scaleOperand = b.create<arith::TruncFOp>(scaleTy, scaleOperand, nullptr,
425427
op.getFastmathAttr());
426428
}
427429
if (!llvm::isa<Float8E8M0FNUType>(scaleETy)) {
@@ -461,8 +463,9 @@ struct ScalingTruncFOpConverter
461463
// allow implicit exponent extraction from 16/32 bits floats
462464
if (scaleETy.getIntOrFloatBitWidth() >= 16) {
463465
scaleETy = b.getF8E8M0Type();
464-
scaleOperand = b.create<arith::TruncFOp>(scaleETy, scaleOperand);
465-
scaleTy = scaleOperand.getType();
466+
scaleTy = cloneToShapedType(scaleTy, scaleETy);
467+
scaleOperand = b.create<arith::TruncFOp>(scaleTy, scaleOperand, nullptr,
468+
op.getFastmathAttr());
466469
}
467470
if (!llvm::isa<Float8E8M0FNUType>(scaleETy)) {
468471
return rewriter.notifyMatchFailure(

mlir/test/Dialect/Arith/expand-ops.mlir

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -336,15 +336,19 @@ func.func @scaling_truncf_vector_f16_to_f6E3M2FN(%arg0 : vector<4xf16>, %arg1: v
336336

337337
// -----
338338

339-
func.func @scaling_truncf_propagate_rounding_mode(%arg0 : vector<4xf16>, %arg1: vector<4xf8E8M0FNU>) -> vector<4xf6E3M2FN> {
340-
%0 = arith.scaling_truncf %arg0, %arg1 to_nearest_even : vector<4xf16>, vector<4xf8E8M0FNU> to vector<4xf6E3M2FN>
339+
func.func @scaling_truncf_propagate_rounding_mode_fast_math(%arg0 : vector<4xf16>, %arg1: vector<4xf16>) -> vector<4xf6E3M2FN> {
340+
%0 = arith.scaling_truncf %arg0, %arg1 to_nearest_even fastmath<fast> : vector<4xf16>, vector<4xf16> to vector<4xf6E3M2FN>
341341
return %0 : vector<4xf6E3M2FN>
342342
}
343-
// SCHECK-LABEL: @scaling_truncf_propagate_rounding_mode
344-
// SCHECK: %[[TRUNCF:.+]] = arith.truncf [[_:%[a-zA-Z0-9_]+]] to_nearest_even : vector<4xf16> to vector<4xf6E3M2FN>
343+
// SCHECK-LABEL: @scaling_truncf_propagate_rounding_mode_fast_math
344+
// SCHECK: %[[SCALEF8:.+]] = arith.truncf %arg1 fastmath<fast> : vector<4xf16> to vector<4xf8E8M0FNU>
345+
// SCHECK: %[[SCALEINTY:.+]] = arith.extf %[[SCALEF8]] fastmath<fast> : vector<4xf8E8M0FNU> to vector<4xf16>
346+
// SCHECK: %[[DIVF:.+]] = arith.divf %arg0, %[[SCALEINTY]] fastmath<fast> : vector<4xf16>
347+
// SCHECK: %[[TRUNCF:.+]] = arith.truncf [[_:%[a-zA-Z0-9_]+]] to_nearest_even fastmath<fast> : vector<4xf16> to vector<4xf6E3M2FN>
345348
// SCHECK: return %[[TRUNCF]] : vector<4xf6E3M2FN>
346349

347350
// -----
351+
348352
func.func @scaling_truncf_f16_to_f4E2M1FN_using_f16_scales(%arg0: f16, %arg1 : f16) -> f4E2M1FN {
349353
%0 = arith.scaling_truncf %arg0, %arg1 : f16, f16 to f4E2M1FN
350354
return %0 : f4E2M1FN
@@ -353,6 +357,15 @@ func.func @scaling_truncf_f16_to_f4E2M1FN_using_f16_scales(%arg0: f16, %arg1 : f
353357
// SCHECK: %[[SCALETRUNCF:.+]] = arith.truncf %arg1 : f16 to f8E8M0FN
354358
// SCHECK: return
355359

360+
// -----
361+
func.func @scaling_truncf_vector_f16_to_f4E2M1FN_using_f16_scales(%arg0: vector<4xf16>, %arg1 : vector<4xf16>) -> vector<4xf4E2M1FN> {
362+
%0 = arith.scaling_truncf %arg0, %arg1 : vector<4xf16>, vector<4xf16> to vector<4xf4E2M1FN>
363+
return %0 : vector<4xf4E2M1FN>
364+
}
365+
// SCHECK-LABEL: @scaling_truncf_vector_f16_to_f4E2M1FN_using_f16_scales
366+
// SCHECK: %[[SCALETRUNCF:.+]] = arith.truncf %arg1 : vector<4xf16> to vector<4xf8E8M0FNU>
367+
// SCHECK: return
368+
356369
// -----
357370

358371
func.func @invalid_scaling_truncf_to_f4E2M1FN(%arg0: f16, %arg1 : f8E5M2FNUZ) -> f4E2M1FN {
@@ -507,6 +520,34 @@ func.func @scaling_extf_vector_to_bf16(%arg0: vector<4xf4E2M1FN>, %arg1 : vector
507520

508521
// -----
509522

523+
func.func @scaling_extf_vector_to_f32_using_f16_scales(%arg0: vector<4xf4E2M1FN>, %arg1 : vector<4xf16>) -> vector<4xf32> {
524+
%0 = arith.scaling_extf %arg0, %arg1 : vector<4xf4E2M1FN>, vector<4xf16> to vector<4xf32>
525+
return %0 : vector<4xf32>
526+
}
527+
528+
// SCHECK-LABEL: @scaling_extf_vector_to_f32_using_f16_scales
529+
// SCHECK: %[[TRUNCF_SCALE:.+]] = arith.truncf %arg1 : vector<4xf16> to vector<4xf8E8M0FNU>
530+
// SCHECK: %[[EXT_SCALE:.+]] = arith.extf %[[TRUNCF_SCALE]] : vector<4xf8E8M0FNU> to vector<4xf32>
531+
// SCHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : vector<4xf4E2M1FN> to vector<4xf32>
532+
// SCHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : vector<4xf32>
533+
// SCHECK: return %[[RESULT]]
534+
535+
// -----
536+
537+
func.func @scaling_extf_vector_to_f32_using_f16_scales_fastmath(%arg0: vector<4xf4E2M1FN>, %arg1 : vector<4xf16>) -> vector<4xf32> {
538+
%0 = arith.scaling_extf %arg0, %arg1 fastmath<fast> : vector<4xf4E2M1FN>, vector<4xf16> to vector<4xf32>
539+
return %0 : vector<4xf32>
540+
}
541+
542+
// SCHECK-LABEL: @scaling_extf_vector_to_f32_using_f16_scales_fastmath
543+
// SCHECK: %[[TRUNCF_SCALE:.+]] = arith.truncf %arg1 fastmath<fast> : vector<4xf16> to vector<4xf8E8M0FNU>
544+
// SCHECK: %[[EXT_SCALE:.+]] = arith.extf %[[TRUNCF_SCALE]] fastmath<fast> : vector<4xf8E8M0FNU> to vector<4xf32>
545+
// SCHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 fastmath<fast> : vector<4xf4E2M1FN> to vector<4xf32>
546+
// SCHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] fastmath<fast> : vector<4xf32>
547+
// SCHECK: return %[[RESULT]]
548+
549+
// -----
550+
510551
func.func @maxsi(%a: i32, %b: i32) -> i32 {
511552
%result = arith.maxsi %a, %b : i32
512553
return %result : i32

0 commit comments

Comments
 (0)