@@ -220,88 +220,45 @@ static std::tuple<Type, Type, Type> getABCElementTypes(MLIRContext *context,
   Type i8 = IntegerType::get(context, 8);
   Type i32 = IntegerType::get(context, 32);
   switch (intrinsic) {
-  case MMAIntrinsic::MFMA_F64_16x16x4_F64: {
+  case MMAIntrinsic::MFMA_F64_16x16x4_F64:
     return {f64, f64, f64};
-  }
-  case MMAIntrinsic::MFMA_F32_16x16x4_F32: {
+  case MMAIntrinsic::MFMA_F32_16x16x4_F32:
     return {f32, f32, f32};
-  }
-  case MMAIntrinsic::MFMA_F32_16x16x16_F16: {
-    return {f16, f16, f32};
-  }
-  case MMAIntrinsic::MFMA_F32_32x32x8_F16: {
+  case MMAIntrinsic::MFMA_F32_16x16x16_F16:
+  case MMAIntrinsic::MFMA_F32_32x32x8_F16:
+  case MMAIntrinsic::WMMA_F32_16x16x16_F16:
+  case MMAIntrinsic::NV_WMMA_F32_16x16x16_F16:
     return {f16, f16, f32};
-  }
-  case MMAIntrinsic::MFMA_F32_16x16x8_BF16: {
-    return {bf16, bf16, f32};
-  }
-  case MMAIntrinsic::MFMA_F32_32x32x4_BF16: {
-    return {bf16, bf16, f32};
-  }
-  case MMAIntrinsic::MFMA_F32_16x16x16_BF16: {
-    return {bf16, bf16, f32};
-  }
-  case MMAIntrinsic::MFMA_F32_32x32x8_BF16: {
+  case MMAIntrinsic::WMMA_F16_16x16x16_F16:
+  case MMAIntrinsic::NV_WMMA_F16_16x16x16_F16:
+    return {f16, f16, f16};
+  case MMAIntrinsic::MFMA_F32_16x16x8_BF16:
+  case MMAIntrinsic::MFMA_F32_32x32x4_BF16:
+  case MMAIntrinsic::MFMA_F32_16x16x16_BF16:
+  case MMAIntrinsic::MFMA_F32_32x32x8_BF16:
+  case MMAIntrinsic::WMMA_F32_16x16x16_BF16:
     return {bf16, bf16, f32};
-  }
-  case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ: {
-    return {f8E4M3FNUZ, f8E4M3FNUZ, f32};
-  }
-  case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ: {
-    return {f8E5M2FNUZ, f8E5M2FNUZ, f32};
-  }
-  case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ: {
-    return {f8E4M3FNUZ, f8E5M2FNUZ, f32};
-  }
-  case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ: {
-    return {f8E5M2FNUZ, f8E4M3FNUZ, f32};
-  }
-  case MMAIntrinsic::MFMA_F32_32x32x16_F8E4M3FNUZ: {
+  case MMAIntrinsic::WMMA_BF16_16x16x16_BF16:
+    return {bf16, bf16, bf16};
+  case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
+  case MMAIntrinsic::MFMA_F32_32x32x16_F8E4M3FNUZ:
     return {f8E4M3FNUZ, f8E4M3FNUZ, f32};
-  }
-  case MMAIntrinsic::MFMA_F32_32x32x16_F8E5M2FNUZ: {
+  case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ:
+  case MMAIntrinsic::MFMA_F32_32x32x16_F8E5M2FNUZ:
     return {f8E5M2FNUZ, f8E5M2FNUZ, f32};
-  }
-  case MMAIntrinsic::MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ: {
+  case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ:
+  case MMAIntrinsic::MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ:
     return {f8E4M3FNUZ, f8E5M2FNUZ, f32};
-  }
-  case MMAIntrinsic::MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ: {
+  case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ:
+  case MMAIntrinsic::MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ:
     return {f8E5M2FNUZ, f8E4M3FNUZ, f32};
-  }
-  case MMAIntrinsic::MFMA_I32_16x16x32_I8: {
-    return {i8, i8, i32};
-  }
-  case MMAIntrinsic::MFMA_I32_32x32x16_I8: {
-    return {i8, i8, i32};
-  }
-  case MMAIntrinsic::MFMA_I32_32x32x8_I8: {
-    return {i8, i8, i32};
-  }
-  case MMAIntrinsic::MFMA_I32_16x16x16_I8: {
-    return {i8, i8, i32};
-  }
-  case MMAIntrinsic::WMMA_F32_16x16x16_F16: {
-    return {f16, f16, f32};
-  }
-  case MMAIntrinsic::WMMA_F16_16x16x16_F16: {
-    return {f16, f16, f16};
-  }
-  case MMAIntrinsic::WMMA_F32_16x16x16_BF16: {
-    return {bf16, bf16, f32};
-  }
-  case MMAIntrinsic::WMMA_BF16_16x16x16_BF16: {
-    return {bf16, bf16, bf16};
-  }
-  case MMAIntrinsic::WMMA_I32_16x16x16_I8: {
+  case MMAIntrinsic::MFMA_I32_16x16x16_I8:
+  case MMAIntrinsic::MFMA_I32_32x32x8_I8:
+  case MMAIntrinsic::MFMA_I32_16x16x32_I8:
+  case MMAIntrinsic::MFMA_I32_32x32x16_I8:
+  case MMAIntrinsic::WMMA_I32_16x16x16_I8:
     return {i8, i8, i32};
   }
-  case MMAIntrinsic::NV_WMMA_F16_16x16x16_F16: {
-    return {f16, f16, f16};
-  }
-  case MMAIntrinsic::NV_WMMA_F32_16x16x16_F16: {
-    return {f16, f16, f32};
-  }
-  }
   assert(false && "unexpected enum value");
   return {};
 }
@@ -498,11 +455,15 @@ MMAAttr::getContractionLayout(vector::ContractionOp contract) const {
   return IREE::GPU::getContractionLayout(contract, layout);
 }
 
-int64_t MMAAttr::getBlockSize() const {
+static int getBlockSize(MMAIntrinsic /*intrinsic*/) {
   // Not supporting any block size other than 1 at the moment.
   return 1;
 }
 
+int64_t MMAAttr::getBlockSize() const {
+  return IREE::GPU::getBlockSize(getIntrinsic().getValue());
+}
+
 static uint32_t getArchID(MMAIntrinsic intrinsic) {
   return static_cast<int>(intrinsic) & 0xFF00;
 }
@@ -704,6 +665,31 @@ SmallVector<VirtualMMAIntrinsic> MMAAttr::getVirtualIntrinsics() const {
   }
 }
 
+static Value createMmaOp(OpBuilder &builder, Location loc,
+                         MMAIntrinsic intrinsic, Type resultType, Value lhs,
+                         Value rhs, Value acc) {
+  auto getVecOrSingleElem = [&](Value vec) -> Value {
+    bool one = llvm::cast<VectorType>(vec.getType()).getNumElements() == 1;
+    return one ? builder.create<vector::ExtractOp>(loc, vec, 0) : vec;
+  };
+  auto layout = getOpaqueMMALayout(builder.getContext(), intrinsic);
+  if (is_AMD_MFMA(intrinsic)) {
+    // MFMA intrinsics want single-element operands of element type, not vector.
+    lhs = getVecOrSingleElem(lhs);
+    rhs = getVecOrSingleElem(rhs);
+    return builder
+        .create<amdgpu::MFMAOp>(loc, resultType, layout.mSize, layout.nSize,
+                                layout.kSize, getBlockSize(intrinsic), lhs, rhs,
+                                acc)
+        .getResult();
+  }
+  if (is_AMD_WMMA(intrinsic)) {
+    return builder.create<amdgpu::WMMAOp>(loc, resultType, lhs, rhs, acc)
+        .getResult();
+  }
+  return {};
+}
+
 // Generates amdgpu.mfma/wmma operation on the given inputs for this attribute
 // type.
 FailureOr<Value> MMAAttr::buildMmaOperation(OpBuilder &builder, Location loc,
@@ -719,23 +705,9 @@ FailureOr<Value> MMAAttr::buildMmaOperation(OpBuilder &builder, Location loc,
   if (cType != resultType) {
     return failure();
   }
-  auto getVecOrSingleElem = [&](Value vec) -> Value {
-    bool one = llvm::cast<VectorType>(vec.getType()).getNumElements() == 1;
-    return one ? builder.create<vector::ExtractOp>(loc, vec, 0) : vec;
-  };
-  MMAIntrinsic intrinsic = getIntrinsic().getValue();
-  if (is_AMD_MFMA(intrinsic)) {
-    // MFMA intrinsics want single-element operands of element type, not vector.
-    lhs = getVecOrSingleElem(lhs);
-    rhs = getVecOrSingleElem(rhs);
-    auto [m, n, k] = getMNKShape();
-    return builder
-        .create<amdgpu::MFMAOp>(loc, resultType, m, n, k, getBlockSize(), lhs,
-                                rhs, acc)
-        .getResult();
-  } else if (is_AMD_WMMA(intrinsic)) {
-    return builder.create<amdgpu::WMMAOp>(loc, resultType, lhs, rhs, acc)
-        .getResult();
+  if (Value value = createMmaOp(builder, loc, getIntrinsic().getValue(),
+                                resultType, lhs, rhs, acc)) {
+    return value;
   }
   return failure();
 }
@@ -1168,23 +1140,18 @@ FailureOr<Value> DataTiledMMAAttr::buildMmaOperation(OpBuilder &builder,
   SmallVector<Value> intrinsicsAcc =
       distributeMmaFragmentToIntrinsics(builder, loc, acc, accSwizzle);
 
-  // Get a MMAAttr for the intrinsic itself, to reuse MMAAttr::buildMmaOperation
-  // to create the target intrinsics.
-  auto intrinsicMma = MMAAttr::get(getContext(), getIntrinsic().getValue());
-  auto [intrinsicAType, intrinsicBType, intrinsicCType] =
-      intrinsicMma.getABCVectorTypes();
+  MMAIntrinsic intrinsic = getIntrinsic().getValue();
+  VectorType intrinCType =
+      getVectorType(builder.getContext(), intrinsic, MMAFragment::Acc);
 
   // Loop over the 3 unroll_{m,n,k} dimensions to create the intrinsics.
   for (int mu = 0; mu < getUnrollM(); ++mu) {
     for (int nu = 0; nu < getUnrollN(); ++nu) {
       for (int ku = 0; ku < getUnrollK(); ++ku) {
-        // Assume intrinsicMma.buildMmaOperation() success: validation should be
-        // completed prior to mutating IR.
         Value lhs = intrinsicsLhs[mu * getUnrollK() + ku];
         Value rhs = intrinsicsRhs[nu * getUnrollK() + ku];
         Value &acc = intrinsicsAcc[mu * getUnrollN() + nu];
-        acc = *intrinsicMma.buildMmaOperation(builder, loc, intrinsicCType, lhs,
-                                              rhs, acc);
+        acc = createMmaOp(builder, loc, intrinsic, intrinCType, lhs, rhs, acc);
       }
     }
   }
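
Since DataTiledMMAAttr now calls createMmaOp directly with a flat index into each distributed fragment list, the expansion pattern is worth spelling out. The following standalone snippet is illustrative only (plain C++, hypothetical unroll factors of 2); it prints which lhs/rhs/acc fragment each emitted intrinsic reads and writes, mirroring the indexing in the loop nest above.

// Standalone illustration of the fragment indexing, not IREE code.
#include <cstdio>

int main() {
  const int unrollM = 2, unrollN = 2, unrollK = 2;  // hypothetical factors
  for (int mu = 0; mu < unrollM; ++mu)
    for (int nu = 0; nu < unrollN; ++nu)
      for (int ku = 0; ku < unrollK; ++ku)
        std::printf("acc[%d] = mma(lhs[%d], rhs[%d], acc[%d])\n",
                    mu * unrollN + nu,  // accumulator fragment, updated K times
                    mu * unrollK + ku,  // LHS fragment
                    nu * unrollK + ku,  // RHS fragment
                    mu * unrollN + nu);
  return 0;
}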