
Commit 20361eb

Authored by Ognjen Plavsic (oplavsic) and Lei Zhang (antiagainst)
Fold fp_to_fp op with zero constant input (#5007)
Fold an fp_to_fp op with a zero constant input into a zero constant of the fp_to_fp op's destination type.

Co-authored-by: Ognjen Plavsic <[email protected]>
Co-authored-by: Lei Zhang <[email protected]>
1 parent a4fd8c6 commit 20361eb

File tree

3 files changed: +96 −0 lines changed


include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 2 additions & 0 deletions
@@ -108,6 +108,8 @@ def TT_FpToFpOp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape,
   let assemblyFormat = "$src attr-dict (`,` `rounding` `=` $rounding^)? `:` type($src) `->` type($result)";
 
   let hasVerifier = 1;
+
+  let hasFolder = 1;
 }
 
 //
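Context for readers (not part of the diff): setting `hasFolder = 1` makes ODS declare an `OpFoldResult fold(FoldAdaptor)` hook on the generated `FpToFpOp` class; the hand-written definition lands in Ops.cpp in the next hunk. Folders declared this way are driven by MLIR's canonicalizer, so a minimal way to exercise the new hook programmatically is to run that standard pass over a module, as sketched below. `runCanonicalizer` is an illustrative helper built from stock MLIR pass APIs, not code from this commit.

#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"

// Illustrative helper: run the canonicalizer, which calls Operation::fold
// on each op and, when fold() returns an attribute (as FpToFpOp::fold does
// for zero constants), materializes it as a constant op in its place.
mlir::LogicalResult runCanonicalizer(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  pm.addPass(mlir::createCanonicalizerPass());
  return pm.run(module);
}

The canonicalize.mlir tests in the last hunk exercise the same path through the canonicalizer pass.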

lib/Dialect/Triton/IR/Ops.cpp

Lines changed: 23 additions & 0 deletions
@@ -728,6 +728,29 @@ LogicalResult ReshapeOp::verify() {
 }
 
 //-- FpToFpOp --
+
+// Fold FpToFpOp when the input operand is a constant zero.
+OpFoldResult FpToFpOp::fold(FoldAdaptor adaptor) {
+  auto srcVal = getSrc();
+  auto dstTy = getType();
+
+  const llvm::fltSemantics &semantic =
+      llvm::cast<FloatType>(dstTy.getElementType()).getFloatSemantics();
+
+  if (matchPattern(srcVal, m_PosZeroFloat())) {
+    llvm::APFloat posZero =
+        llvm::APFloat::getZero(semantic, /*negative=*/false);
+    return DenseFPElementsAttr::get(dstTy, posZero);
+  }
+
+  if (matchPattern(srcVal, m_NegZeroFloat())) {
+    llvm::APFloat negZero = llvm::APFloat::getZero(semantic, /*negative=*/true);
+    return DenseFPElementsAttr::get(dstTy, negZero);
+  }
+
+  return {};
+}
+
 LogicalResult FpToFpOp::verify() {
   auto dstType = getType().getElementType();
   auto srcType = getSrc().getType().getElementType();
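One detail worth noting in the folder above: `m_PosZeroFloat()` and `m_NegZeroFloat()` read like scalar matchers, but MLIR documents them as also matching vector and tensor splat constants, which is why a `dense<0.0>` tensor operand matches here. Below is a standalone sketch of the same check; `isSplatZeroFloat` is an illustrative name, not part of this commit.

#include "mlir/IR/Matchers.h"
#include "mlir/IR/Value.h"

// Illustrative helper mirroring the two matchPattern checks in
// FpToFpOp::fold: true when `v` is a constant float zero of either sign,
// whether a scalar constant or a splat dense vector/tensor constant.
static bool isSplatZeroFloat(mlir::Value v) {
  return mlir::matchPattern(v, mlir::m_PosZeroFloat()) ||
         mlir::matchPattern(v, mlir::m_NegZeroFloat());
}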

test/Triton/canonicalize.mlir

Lines changed: 71 additions & 0 deletions
@@ -50,3 +50,74 @@ tt.func @fn(%arg0: tensor<1xf32, #sliced0>) -> (tensor<32x1xf32, #blocked0>){
   tt.return %b : tensor<32x1xf32, #blocked0>
 }
 } // end module
+
+// -----
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
+tt.func @fp_to_fp_pos_zero_fold() -> tensor<32x128xf8E4M3FNUZ, #blocked> {
+  // CHECK-LABEL: fp_to_fp_pos_zero_fold
+  // CHECK-NEXT: %[[cst_folded:.+]] = arith.constant dense<0.000000e+00> : tensor<32x128xf8E4M3FNUZ, #blocked>
+  // CHECK-NEXT: tt.return %[[cst_folded]]
+  %cst = arith.constant dense<0.00e+00> : tensor<32x128xf32, #blocked>
+  %cst_converted = tt.fp_to_fp %cst, rounding = rtne : tensor<32x128xf32, #blocked> -> tensor<32x128xf8E4M3FNUZ, #blocked>
+  tt.return %cst_converted : tensor<32x128xf8E4M3FNUZ, #blocked>
+}
+} // end module
+
+// -----
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
+tt.func @fp_to_fp_neg_zero_fold() -> tensor<32x128xf8E4M3FN, #blocked> {
+  // CHECK-LABEL: fp_to_fp_neg_zero_fold
+  // CHECK-NEXT: %[[cst_folded:.+]] = arith.constant dense<-0.000000e+00> : tensor<32x128xf8E4M3FN, #blocked>
+  // CHECK-NEXT: tt.return %[[cst_folded]]
+  %cst = arith.constant dense<-0.00e+00> : tensor<32x128xf32, #blocked>
+  %cst_converted = tt.fp_to_fp %cst, rounding = rtne : tensor<32x128xf32, #blocked> -> tensor<32x128xf8E4M3FN, #blocked>
+  tt.return %cst_converted : tensor<32x128xf8E4M3FN, #blocked>
+}
+} // end module
+
+// -----
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
+tt.func @fp_to_fp_neg_zero_fold() -> tensor<32x128xf8E4M3FNUZ, #blocked> {
+  // CHECK-LABEL: fp_to_fp_neg_zero_fold
+  // We fold to positive zero here because, by definition, f8E4M3FNUZ has no negative-zero encoding.
+  // CHECK-NEXT: %[[cst_folded:.+]] = arith.constant dense<0.000000e+00> : tensor<32x128xf8E4M3FNUZ, #blocked>
+  // CHECK-NEXT: tt.return %[[cst_folded]]
+  %cst = arith.constant dense<-0.00e+00> : tensor<32x128xf32, #blocked>
+  %cst_converted = tt.fp_to_fp %cst, rounding = rtne : tensor<32x128xf32, #blocked> -> tensor<32x128xf8E4M3FNUZ, #blocked>
+  tt.return %cst_converted : tensor<32x128xf8E4M3FNUZ, #blocked>
+}
+} // end module
+
+// -----
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
+tt.func @fold_fp_to_fp_non_zero_nofold() -> tensor<32x128xf8E4M3FNUZ, #blocked> {
+  // CHECK-LABEL: fold_fp_to_fp_non_zero_nofold
+  // CHECK-NEXT: %[[cst:.+]] = arith.constant dense<0xFF800000> : tensor<32x128xf32, #blocked>
+  // CHECK-NEXT: %[[cst_cvt:.+]] = tt.fp_to_fp %[[cst]]
+  // CHECK-NEXT: tt.return %[[cst_cvt]]
+  %cst = arith.constant dense<0xFF800000> : tensor<32x128xf32, #blocked>
+  %cst_converted = tt.fp_to_fp %cst, rounding = rtne : tensor<32x128xf32, #blocked> -> tensor<32x128xf8E4M3FNUZ, #blocked>
+  tt.return %cst_converted : tensor<32x128xf8E4M3FNUZ, #blocked>
+}
+} // end module
+
+// -----
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
+tt.func @fold_fp_to_fp_non_constant_nofold(%arg0: tensor<32x128xf32, #blocked>) -> tensor<32x128xf8E4M3FNUZ, #blocked> {
+  // CHECK-LABEL: fold_fp_to_fp_non_constant_nofold
+  // CHECK-NEXT: %[[arg_cvt:.+]] = tt.fp_to_fp %arg0
+  // CHECK-NEXT: tt.return %[[arg_cvt]]
+  %cst_converted = tt.fp_to_fp %arg0, rounding = rtne : tensor<32x128xf32, #blocked> -> tensor<32x128xf8E4M3FNUZ, #blocked>
+  tt.return %cst_converted : tensor<32x128xf8E4M3FNUZ, #blocked>
+}
+} // end module
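The third test above depends on a subtlety of LLVM's APFloat: asking for a negative zero in a format that has none clamps the sign, so `APFloat::getZero(semantic, /*negative=*/true)` yields +0 for the FNUZ formats, whose sign-bit zero pattern encodes NaN instead. A standalone check under that assumption; not part of this commit.

#include "llvm/ADT/APFloat.h"
#include "llvm/Support/raw_ostream.h"

int main() {
  // Request a negative zero in f8E4M3FNUZ; APFloat is expected to clamp it
  // to +0 because the format reserves the sign-bit zero encoding for NaN.
  llvm::APFloat z = llvm::APFloat::getZero(llvm::APFloat::Float8E4M3FNUZ(),
                                           /*negative=*/true);
  llvm::outs() << (z.isNegative() ? "-0" : "+0") << "\n"; // prints "+0"
  return 0;
}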
