Commit d556ce9
[BACKEND] Support scalar fp_to_fp (#5132)
Now that fp_to_fp is marked as elementwise, we may have a scalar version of this op.
1 parent: 385671e
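For illustration, a minimal sketch of what tt.fp_to_fp now accepts after this change: both plain floats and float tensors verify, while the rounding rule for downcasts is unchanged. The function names below are hypothetical; the scalar form mirrors the new test added in test/Triton/canonicalize.mlir.

tt.func @fp_to_fp_scalar(%x: f32) -> f8E4M3FNUZ {
  // Scalar downcast; a rounding mode is still required when narrowing.
  %y = tt.fp_to_fp %x, rounding = rtne : f32 -> f8E4M3FNUZ
  tt.return %y : f8E4M3FNUZ
}

tt.func @fp_to_fp_tensor(%t: tensor<128xf8E4M3FNUZ>) -> tensor<128xf32> {
  // Tensor upcast; no rounding attribute is needed when widening.
  %u = tt.fp_to_fp %t : tensor<128xf8E4M3FNUZ> -> tensor<128xf32>
  tt.return %u : tensor<128xf32>
}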

File tree: 6 files changed, +34 -11 lines

include/triton/Dialect/Triton/IR/TritonOps.td (2 additions, 2 deletions)

@@ -100,11 +100,11 @@ def TT_FpToFpOp : TT_Op<"fp_to_fp", [Elementwise,
   }];

   let arguments = (
-    ins TT_FloatTensor:$src,
+    ins TT_FloatLike:$src,
     OptionalAttr<TT_RoundingModeAttr>:$rounding
   );

-  let results = (outs TT_FloatTensor:$result);
+  let results = (outs TT_FloatLike:$result);

   let assemblyFormat = "$src attr-dict (`,` `rounding` `=` $rounding^)? `:` type($src) `->` type($result)";

lib/Dialect/Triton/IR/Ops.cpp (14 additions, 6 deletions)

@@ -734,26 +734,34 @@ OpFoldResult FpToFpOp::fold(FoldAdaptor adaptor) {
   auto srcVal = getSrc();
   auto dstTy = getType();

-  const llvm::fltSemantics &semantic =
-      llvm::cast<FloatType>(dstTy.getElementType()).getFloatSemantics();
+  auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
+  const llvm::fltSemantics &semantic = resElemType.getFloatSemantics();

   if (matchPattern(srcVal, m_PosZeroFloat())) {
     llvm::APFloat posZero =
         llvm::APFloat::getZero(semantic, /*negative=*/false);
-    return DenseFPElementsAttr::get(dstTy, posZero);
+    if (auto tensorTy = dyn_cast<RankedTensorType>(dstTy))
+      return DenseElementsAttr::get(tensorTy, posZero);
+    return Builder(getContext()).getFloatAttr(resElemType, posZero);
   }

   if (matchPattern(srcVal, m_NegZeroFloat())) {
     llvm::APFloat negZero = llvm::APFloat::getZero(semantic, /*negative=*/true);
-    return DenseFPElementsAttr::get(dstTy, negZero);
+    if (auto tensorTy = dyn_cast<RankedTensorType>(dstTy))
+      return DenseElementsAttr::get(tensorTy, negZero);
+    return Builder(getContext()).getFloatAttr(resElemType, negZero);
   }

   return {};
 }

 LogicalResult FpToFpOp::verify() {
-  auto dstType = getType().getElementType();
-  auto srcType = getSrc().getType().getElementType();
+  auto dstType = getType();
+  auto srcType = getSrc().getType();
+  if (auto dstTensorType = dyn_cast<RankedTensorType>(dstType))
+    dstType = dstTensorType.getElementType();
+  if (auto srcTensorType = dyn_cast<RankedTensorType>(srcType))
+    srcType = srcTensorType.getElementType();
   if ((dstType.getIntOrFloatBitWidth() < srcType.getIntOrFloatBitWidth()) &&
       (!getRounding().has_value())) {
     return emitError("Rounding mode is required for FP downcast");
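As a sketch of the fold behavior above: a zero source now folds to a scalar arith.constant when the result is a plain float, and still folds to a dense splat when the result is a tensor. Function names are illustrative; the scalar case matches the test added below.

tt.func @fold_scalar_zero() -> f8E4M3FNUZ {
  // A +0.0 source folds to a scalar arith.constant of the result type.
  %cst = arith.constant 0.00e+00 : f32
  %v = tt.fp_to_fp %cst, rounding = rtne : f32 -> f8E4M3FNUZ
  tt.return %v : f8E4M3FNUZ
}

tt.func @fold_tensor_zero() -> tensor<32xf8E4M3FNUZ> {
  // A splat -0.0 tensor source still folds to a dense splat constant.
  %cst = arith.constant dense<-0.00e+00> : tensor<32xf32>
  %v = tt.fp_to_fp %cst, rounding = rtne : tensor<32xf32> -> tensor<32xf8E4M3FNUZ>
  tt.return %v : tensor<32xf8E4M3FNUZ>
}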

lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp (2 additions, 1 deletion)

@@ -495,7 +495,8 @@ class DecomposeScaledBlocked
       assert(type == ScaleDotElemType::E5M2 || type == ScaleDotElemType::E4M3);
       auto vTypeBf16 = RankedTensorType::get(
           newVType.getShape(), rewriter.getBF16Type(), newVType.getEncoding());
-      ret = rewriter.create<FpToFpOp>(v.getLoc(), vTypeBf16, ret);
+      ret = cast<TypedValue<RankedTensorType>>(
+          rewriter.create<FpToFpOp>(v.getLoc(), vTypeBf16, ret).getResult());
     }
     if (opt_scale.has_value()) {
       auto scale = *opt_scale;

lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp (1 addition, 1 deletion)

@@ -1024,7 +1024,7 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast(
     if (auto fpToFpOp = dyn_cast<FpToFpOp>(op)) {
       auto srcType = cast<RankedTensorType>(fpToFpOp.getOperand().getType());
       return getElementBitWidth(srcType) <
-             getElementBitWidth(fpToFpOp.getType());
+             getElementBitWidth(cast<RankedTensorType>(fpToFpOp.getType()));
     }
     return false;
   };

test/Triton/canonicalize.mlir (13 additions, 0 deletions)

@@ -67,6 +67,19 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :

 // -----

+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
+  tt.func @fp_to_fp_pos_zero_fold_scalar() -> f8E4M3FNUZ {
+    // CHECK-LABEL: fp_to_fp_pos_zero_fold_scalar
+    // CHECK-NEXT: %[[cst_folded:.+]] = arith.constant 0.000000e+00 : f8E4M3FNUZ
+    // CHECK-NEXT: tt.return %[[cst_folded]]
+    %cst = arith.constant 0.00e+00 : f32
+    %cst_converted = tt.fp_to_fp %cst, rounding = rtne : f32 -> f8E4M3FNUZ
+    tt.return %cst_converted : f8E4M3FNUZ
+  }
+} // end module
+
+// -----
+
 #blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
   tt.func @fp_to_fp_neg_zero_fold() -> tensor<32x128xf8E4M3FN, #blocked> {

third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp (2 additions, 1 deletion)

@@ -580,7 +580,8 @@ class ScaledBlockedToMFMA final : public OpRewritePattern<triton::DotScaledOp> {

       auto vTypeBf16 = RankedTensorType::get(
           vType.getShape(), rewriter.getBF16Type(), newVEncoding);
-      return rewriter.create<FpToFpOp>(v.getLoc(), vTypeBf16, v);
+      return cast<TensorValue>(
+          rewriter.create<FpToFpOp>(v.getLoc(), vTypeBf16, v).getResult());
     };
     a = toMMABf16(a, 0, aElemType);
     b = toMMABf16(b, 1, bElemType);
