|
19 | 19 | using namespace mlir; |
20 | 20 | using namespace mlir::triton; |
21 | 21 | using namespace mlir::triton::gpu; |
| 22 | +using ::mlir::LLVM::AMD::upcast4xMxfp8_HW; |
| 23 | +using ::mlir::LLVM::AMD::upcast8xMxfp4_HW; |
22 | 24 | using ::mlir::LLVM::AMD::upcast8xMxfp4_SW; |
23 | 25 |
|
24 | 26 | namespace { |
@@ -85,54 +87,6 @@ Value mxfpScaleBf16ViaF32(RewriterBase &rewriter, Location loc, Value v, |
85 | 87 | return b.select(scaleIsNan, nanBf16, mulBf16); |
86 | 88 | }; |
87 | 89 |
|
88 | | -template <typename ConvertOp> |
89 | | -SmallVector<Value, 4> upcast8xMxfp4_HW(RewriterBase &rewriter, Location loc, |
90 | | - ArrayRef<Value> xVals, int idx, |
91 | | - Value scale) { |
92 | | - auto b = TritonLLVMOpBuilder(loc, rewriter); |
93 | | - Value packedVec = b.undef(vec_ty(i8_ty, 4)); |
94 | | - for (int i : llvm::seq(4)) |
95 | | - packedVec = b.insert_element(packedVec, xVals[idx + i], b.i32_val(i)); |
96 | | - packedVec = b.bitcast(packedVec, i32_ty); |
97 | | - Type retElemType = bf16_ty; |
98 | | - if constexpr (std::is_same_v<ConvertOp, ROCDL::CvtScaleF32PkF16Fp4Op>) |
99 | | - retElemType = f16_ty; |
100 | | - Type resType = vec_ty(retElemType, 2); |
101 | | - Value scaleF32 = |
102 | | - b.bitcast(b.shl(b.zext(i32_ty, scale), b.i32_val(23)), f32_ty); |
103 | | - SmallVector<Value, 4> results; |
104 | | - for (int srcSelIndex : llvm::seq(4)) |
105 | | - results.push_back(rewriter.create<ConvertOp>(loc, resType, packedVec, |
106 | | - scaleF32, srcSelIndex)); |
107 | | - return results; |
108 | | -} |
109 | | - |
110 | | -template <typename ConvertOp> |
111 | | -SmallVector<Value, 2> upcast4xMxfp8_HW(RewriterBase &rewriter, Location loc, |
112 | | - ArrayRef<Value> xVals, int idx, |
113 | | - Value scale) { |
114 | | - auto b = TritonLLVMOpBuilder(loc, rewriter); |
115 | | - Value packedVec = b.undef(vec_ty(i8_ty, 4)); |
116 | | - for (int i : llvm::seq(4)) |
117 | | - packedVec = b.insert_element(packedVec, xVals[idx + i], b.i32_val(i)); |
118 | | - packedVec = b.bitcast(packedVec, i32_ty); |
119 | | - Type retElemType = bf16_ty; |
120 | | - if constexpr (std::is_same_v<ConvertOp, ROCDL::CvtScaleF32PkF16Fp8Op> || |
121 | | - std::is_same_v<ConvertOp, ROCDL::CvtScaleF32PkF16Bf8Op>) |
122 | | - retElemType = f16_ty; |
123 | | - Type resType = vec_ty(retElemType, 2); |
124 | | - Value scaleF32 = |
125 | | - b.bitcast(b.shl(b.zext(i32_ty, scale), b.i32_val(23)), f32_ty); |
126 | | - SmallVector<Value, 2> results; |
127 | | - results.push_back(rewriter.create<ConvertOp>(loc, resType, packedVec, |
128 | | - scaleF32, |
129 | | - /*srcLoHiSel=*/false)); |
130 | | - results.push_back(rewriter.create<ConvertOp>(loc, resType, packedVec, |
131 | | - scaleF32, |
132 | | - /*srcLoHiSel=*/true)); |
133 | | - return results; |
134 | | -} |
135 | | - |
136 | 90 | // Upcast 8 mxfp4 values from xVals starting at idx using the given scale |
137 | 91 | // factor, and store the results into yVals |
138 | 92 | static void upcast8xMxfp4(RewriterBase &rewriter, Location loc, |
|
0 commit comments