@@ -11,138 +11,23 @@ using namespace mlir::triton::gpu;
1111
1212namespace mlir ::triton::gpu {
1313
14- namespace {
15-
16- bool isDotOpTensorAndPacked (Type srcTy) {
17- auto tensorTy = dyn_cast<RankedTensorType>(srcTy);
18- if (!tensorTy)
19- return false ;
20- auto encoding = dyn_cast<DotOperandEncodingAttr>(tensorTy.getEncoding ());
21- if (!encoding)
22- return false ;
23- auto parentEnc = dyn_cast<NvidiaMmaEncodingAttr>(encoding.getParent ());
24- // By code convention, values for Hopper's dotOp-encoded tensors are not
25- // packed
26- if (!parentEnc || parentEnc.isHopper ())
27- return false ;
28- return true ;
29- }
30-
31- } // namespace
32-
3314Type getElementType (Value value) {
3415 auto type = value.getType ();
3516 if (auto tensorType = dyn_cast<RankedTensorType>(type))
3617 return tensorType.getElementType ();
3718 return type;
3819}
39- // MMA encoding has a different order depending on the element's bit width;
40- // reorder if we're in this case.
41- SmallVector<Value> reorderValues (const SmallVector<Value> &values, Type inType,
42- Type ouType) {
43- auto inTensorTy = dyn_cast<RankedTensorType>(inType);
44- auto ouTensorTy = dyn_cast<RankedTensorType>(ouType);
45- if (!inTensorTy || !ouTensorTy)
46- return values;
47- auto inEncoding = dyn_cast<DotOperandEncodingAttr>(inTensorTy.getEncoding ());
48- auto ouEncoding = dyn_cast<DotOperandEncodingAttr>(ouTensorTy.getEncoding ());
49- assert (inEncoding == ouEncoding);
50- if (!inEncoding)
51- return values;
52- // If the parent of the dot operand is in block encoding, we don't need to
53- // reorder elements
54- auto parentEncoding = dyn_cast<NvidiaMmaEncodingAttr>(ouEncoding.getParent ());
55- if (!parentEncoding || parentEncoding.isHopper ())
56- return values;
57- size_t inBitWidth = inTensorTy.getElementType ().getIntOrFloatBitWidth ();
58- size_t ouBitWidth = ouTensorTy.getElementType ().getIntOrFloatBitWidth ();
59- auto ouEltTy = ouTensorTy.getElementType ();
60- if (inBitWidth == ouBitWidth)
61- return values;
62- if (inBitWidth == 16 && ouBitWidth == 32 ) {
63- // Register layout conversion:
64- //
65- // [0, 1], [4, 5] ⟶ [0], [1], [4], [5]
66- // [2, 3], [6, 7] [2], [3], [6], [7]
67- //
68- // Original access order:
69- //
70- // [0, 1], [2, 3], [4, 5], [6, 7]
71- //
72- // Transformed access order:
73- //
74- // [0], [2], [1], [3], [4], [6], [5], [7]
75- SmallVector<Value> ret;
76- for (unsigned i = 0 ; i < values.size (); i += 8 ) {
77- ret.push_back (values[i]);
78- ret.push_back (values[i + 2 ]);
79- ret.push_back (values[i + 1 ]);
80- ret.push_back (values[i + 3 ]);
81- ret.push_back (values[i + 4 ]);
82- ret.push_back (values[i + 6 ]);
83- ret.push_back (values[i + 5 ]);
84- ret.push_back (values[i + 7 ]);
85- }
86- return ret;
87- }
88- if (inBitWidth == 8 && ouBitWidth == 16 ) {
89- // Register layout conversion:
90- //
91- // [0, 1, 2, 3], [8, 9, 10, 11] ⟶ [0, 1], [2, 3], [8, 9], [10, 11]
92- // [4, 5, 6, 7], [12, 13, 14, 15] [4, 5], [6, 7], [12, 13], [14, 15]
93- //
94- // Original access order:
95- //
96- // [0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]
97- //
98- // Transformed access order:
99- //
100- // [0, 1], [4, 5], [2, 3], [6, 7], [8, 9], [12, 13], [10, 11], [14, 15]
101- SmallVector<Value> ret;
102- for (unsigned i = 0 ; i < values.size (); i += 16 ) {
103- ret.push_back (values[i]);
104- ret.push_back (values[i + 1 ]);
105- ret.push_back (values[i + 4 ]);
106- ret.push_back (values[i + 5 ]);
107- ret.push_back (values[i + 2 ]);
108- ret.push_back (values[i + 3 ]);
109- ret.push_back (values[i + 6 ]);
110- ret.push_back (values[i + 7 ]);
111- ret.push_back (values[i + 8 ]);
112- ret.push_back (values[i + 9 ]);
113- ret.push_back (values[i + 12 ]);
114- ret.push_back (values[i + 13 ]);
115- ret.push_back (values[i + 10 ]);
116- ret.push_back (values[i + 11 ]);
117- ret.push_back (values[i + 14 ]);
118- ret.push_back (values[i + 15 ]);
119- }
120- return ret;
121- }
122- llvm_unreachable (" unimplemented code path" );
123- }
12420
12521int getNumElementsPerThreads (Type type,
12622 const LLVMTypeConverter *typeConverter) {
12723 int numElemsPerThread = 1 ;
128- auto tensorTy = dyn_cast<RankedTensorType>(type);
129- if (!tensorTy)
130- return numElemsPerThread;
131- auto structType =
132- dyn_cast<LLVM::LLVMStructType>(typeConverter->convertType (type));
133- if (structType) {
134- numElemsPerThread = structType.getBody ().size ();
24+ if (auto tensorTy = dyn_cast<RankedTensorType>(type)) {
25+ auto structType =
26+ dyn_cast<LLVM::LLVMStructType>(typeConverter->convertType (type));
27+ if (structType)
28+ numElemsPerThread = structType.getBody ().size ();
13529 }
136- auto encoding = dyn_cast<DotOperandEncodingAttr>(tensorTy.getEncoding ());
137- if (!(encoding && isa<NvidiaMmaEncodingAttr>(encoding.getParent ())))
138- return numElemsPerThread;
139- auto eltType = tensorTy.getElementType ();
140- assert (eltType.getIntOrFloatBitWidth () <= 32 &&
141- " Only support element type with bit width <= 32 in dot operand mma "
142- " layout" );
143- // dot operand data are packed into i32 elements so use the following formula
144- // to get the number of elements per thread.
145- return (32 / eltType.getIntOrFloatBitWidth ()) * numElemsPerThread;
30+ return numElemsPerThread;
14631}
14732
14833} // namespace mlir::triton::gpu
@@ -473,8 +358,7 @@ struct ElementwiseInlineAsmOpConversion
473358 for (auto operand : adaptor.getOperands ()) {
474359 auto argTy = op->getOperand (0 ).getType ();
475360 auto subOperands = unpackLLElements (loc, operand, rewriter);
476- unpackedOperands.push_back (
477- unpackI32s (subOperands, argTy, rewriter, loc, getTypeConverter ()));
361+ unpackedOperands.push_back (subOperands);
478362 }
479363
480364 int numElemsPerThread = getNumElementsPerThreads (op->getResult (0 ).getType (),
@@ -527,16 +411,6 @@ struct ElementwiseInlineAsmOpConversion
527411 // Reorder and pack the results.
528412 SmallVector<Value> outs;
529413 for (int i = 0 ; i < unpackedResults.size (); i++) {
530- // We reordered all the inputs so they match operand 0. Reorder the
531- // outputs accordingly.
532- if (op->getNumOperands () > 0 ) {
533- unpackedResults[i] = reorderValues (
534- unpackedResults[i], /* inType=*/ op->getOperand (0 ).getType (),
535- /* ouType=*/ op->getResult (i).getType ());
536- }
537- auto dstTy = op->getResult (i).getType ();
538- unpackedResults[i] = packI32s (unpackedResults[i], dstTy, rewriter, loc,
539- getTypeConverter ());
540414 outs.push_back (packLLElements (loc, getTypeConverter (), unpackedResults[i],
541415 rewriter, op->getResult (i).getType ()));
542416 }
0 commit comments