Skip to content

Commit ddada27

Browse files
authored
[AMD][NFC] Expose Hardware Upcasting Utilities (#8076)
Exposed hardware upcasting utilities to be used by other passes. This is one of a series of PRs to decompose scaled dot on AMD backend.
1 parent c52137f commit ddada27

File tree

2 files changed

+50
-48
lines changed

2 files changed

+50
-48
lines changed

third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp

Lines changed: 2 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
using namespace mlir;
2020
using namespace mlir::triton;
2121
using namespace mlir::triton::gpu;
22+
using ::mlir::LLVM::AMD::upcast4xMxfp8_HW;
23+
using ::mlir::LLVM::AMD::upcast8xMxfp4_HW;
2224
using ::mlir::LLVM::AMD::upcast8xMxfp4_SW;
2325

2426
namespace {
@@ -85,54 +87,6 @@ Value mxfpScaleBf16ViaF32(RewriterBase &rewriter, Location loc, Value v,
8587
return b.select(scaleIsNan, nanBf16, mulBf16);
8688
};
8789

88-
template <typename ConvertOp>
89-
SmallVector<Value, 4> upcast8xMxfp4_HW(RewriterBase &rewriter, Location loc,
90-
ArrayRef<Value> xVals, int idx,
91-
Value scale) {
92-
auto b = TritonLLVMOpBuilder(loc, rewriter);
93-
Value packedVec = b.undef(vec_ty(i8_ty, 4));
94-
for (int i : llvm::seq(4))
95-
packedVec = b.insert_element(packedVec, xVals[idx + i], b.i32_val(i));
96-
packedVec = b.bitcast(packedVec, i32_ty);
97-
Type retElemType = bf16_ty;
98-
if constexpr (std::is_same_v<ConvertOp, ROCDL::CvtScaleF32PkF16Fp4Op>)
99-
retElemType = f16_ty;
100-
Type resType = vec_ty(retElemType, 2);
101-
Value scaleF32 =
102-
b.bitcast(b.shl(b.zext(i32_ty, scale), b.i32_val(23)), f32_ty);
103-
SmallVector<Value, 4> results;
104-
for (int srcSelIndex : llvm::seq(4))
105-
results.push_back(rewriter.create<ConvertOp>(loc, resType, packedVec,
106-
scaleF32, srcSelIndex));
107-
return results;
108-
}
109-
110-
template <typename ConvertOp>
111-
SmallVector<Value, 2> upcast4xMxfp8_HW(RewriterBase &rewriter, Location loc,
112-
ArrayRef<Value> xVals, int idx,
113-
Value scale) {
114-
auto b = TritonLLVMOpBuilder(loc, rewriter);
115-
Value packedVec = b.undef(vec_ty(i8_ty, 4));
116-
for (int i : llvm::seq(4))
117-
packedVec = b.insert_element(packedVec, xVals[idx + i], b.i32_val(i));
118-
packedVec = b.bitcast(packedVec, i32_ty);
119-
Type retElemType = bf16_ty;
120-
if constexpr (std::is_same_v<ConvertOp, ROCDL::CvtScaleF32PkF16Fp8Op> ||
121-
std::is_same_v<ConvertOp, ROCDL::CvtScaleF32PkF16Bf8Op>)
122-
retElemType = f16_ty;
123-
Type resType = vec_ty(retElemType, 2);
124-
Value scaleF32 =
125-
b.bitcast(b.shl(b.zext(i32_ty, scale), b.i32_val(23)), f32_ty);
126-
SmallVector<Value, 2> results;
127-
results.push_back(rewriter.create<ConvertOp>(loc, resType, packedVec,
128-
scaleF32,
129-
/*srcLoHiSel=*/false));
130-
results.push_back(rewriter.create<ConvertOp>(loc, resType, packedVec,
131-
scaleF32,
132-
/*srcLoHiSel=*/true));
133-
return results;
134-
}
135-
13690
// Upcast 8 mxfp4 values from xVals starting at idx using the given scale
13791
// factor, and store the results into yVals
13892
static void upcast8xMxfp4(RewriterBase &rewriter, Location loc,

third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,54 @@ bool isChainDotTail(mlir::triton::DotOpInterface dotOp);
127127
// to a wider type: BF16 or FP16
128128
SmallVector<Value, 4> upcast8xMxfp4_SW(RewriterBase &rewriter, Operation *op,
129129
bool toFp16, Value packedVec);
130+
131+
template <typename ConvertOp>
132+
SmallVector<Value, 4> upcast8xMxfp4_HW(RewriterBase &rewriter, Location loc,
133+
ArrayRef<Value> xVals, int idx,
134+
Value scale) {
135+
auto b = TritonLLVMOpBuilder(loc, rewriter);
136+
Value packedVec = b.undef(vec_ty(i8_ty, 4));
137+
for (int i : llvm::seq(4))
138+
packedVec = b.insert_element(packedVec, xVals[idx + i], b.i32_val(i));
139+
packedVec = b.bitcast(packedVec, i32_ty);
140+
Type retElemType = bf16_ty;
141+
if constexpr (std::is_same_v<ConvertOp, ROCDL::CvtScaleF32PkF16Fp4Op>)
142+
retElemType = f16_ty;
143+
Type resType = vec_ty(retElemType, 2);
144+
Value scaleF32 =
145+
b.bitcast(b.shl(b.zext(i32_ty, scale), b.i32_val(23)), f32_ty);
146+
SmallVector<Value, 4> results;
147+
for (int srcSelIndex : llvm::seq(4))
148+
results.push_back(rewriter.create<ConvertOp>(loc, resType, packedVec,
149+
scaleF32, srcSelIndex));
150+
return results;
151+
}
152+
153+
template <typename ConvertOp>
154+
SmallVector<Value, 2> upcast4xMxfp8_HW(RewriterBase &rewriter, Location loc,
155+
ArrayRef<Value> xVals, int idx,
156+
Value scale) {
157+
auto b = TritonLLVMOpBuilder(loc, rewriter);
158+
Value packedVec = b.undef(vec_ty(i8_ty, 4));
159+
for (int i : llvm::seq(4))
160+
packedVec = b.insert_element(packedVec, xVals[idx + i], b.i32_val(i));
161+
packedVec = b.bitcast(packedVec, i32_ty);
162+
Type retElemType = bf16_ty;
163+
if constexpr (std::is_same_v<ConvertOp, ROCDL::CvtScaleF32PkF16Fp8Op> ||
164+
std::is_same_v<ConvertOp, ROCDL::CvtScaleF32PkF16Bf8Op>)
165+
retElemType = f16_ty;
166+
Type resType = vec_ty(retElemType, 2);
167+
Value scaleF32 =
168+
b.bitcast(b.shl(b.zext(i32_ty, scale), b.i32_val(23)), f32_ty);
169+
SmallVector<Value, 2> results;
170+
results.push_back(rewriter.create<ConvertOp>(loc, resType, packedVec,
171+
scaleF32,
172+
/*srcLoHiSel=*/false));
173+
results.push_back(rewriter.create<ConvertOp>(loc, resType, packedVec,
174+
scaleF32,
175+
/*srcLoHiSel=*/true));
176+
return results;
177+
}
130178
} // namespace mlir::LLVM::AMD
131179

132180
#endif // TRITON_THIRD_PARTY_AMD_LIB_TRITONAMDGPUTOLLVM_UTILITY_H_

0 commit comments

Comments
 (0)