@@ -411,117 +411,6 @@ emitBaseIndexForDpasLayout(Location loc, RewriterBase &rewriter,
 
 namespace mlir::triton::intel {
 
-inline SmallVector<SmallVector<unsigned>>
-emitOffsetForLayout(Attribute layout, RankedTensorType type);
-
-// -----------------------------------------------------------------------
-// Get offsets / indices for any layout
-// -----------------------------------------------------------------------
-
-inline SmallVector<Value>
-emitBaseIndexForLayoutImpl(Location loc, RewriterBase &rewriter,
-                           const TargetInfoBase &target, Attribute layout,
-                           RankedTensorType type, bool withCTAOffset) {
-  auto b = TritonLLVMOpBuilder(loc, rewriter);
-  auto shape = type.getShape();
-
-  SmallVector<Value> baseIndex;
-  RewriterBase::InsertionGuard guard(rewriter);
-  SmallVector<Value> result;
-  if (auto dpasLayout = dyn_cast<DpasEncodingAttr>(layout)) {
-    result = emitBaseIndexForDpasLayout(loc, rewriter, dpasLayout, type);
-  } else if (auto sliceLayout = dyn_cast<SliceEncodingAttr>(layout)) {
-    auto parentLayout = sliceLayout.getParent();
-    auto parentShape = sliceLayout.paddedShape(type.getShape());
-    RankedTensorType parentTy =
-        RankedTensorType::get(parentShape, type.getElementType(), parentLayout);
-    result = ::intel::emitBaseIndexForLayoutImpl(
-        loc, rewriter, target, parentLayout, parentTy, withCTAOffset);
-    result.erase(result.begin() + sliceLayout.getDim());
-    // CTAOffset has been added in emitBaseIndexForLayout of parentLayout
-    return result;
-  } else if (auto dotLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
-    result = emitBaseIndexForDotOpLayout(loc, rewriter, dotLayout, type);
-  } else {
-    return mlir::emitBaseIndexForLayoutImpl(loc, rewriter, target, layout, type,
-                                            withCTAOffset);
-  }
-  if (withCTAOffset) {
-    auto CTAOffset =
-        emitCTAOffsetForLayout(loc, rewriter, target, layout, shape);
-    assert(CTAOffset.size() == result.size() && "Rank mismatch");
-    for (unsigned k = 0; k < result.size(); ++k) {
-      // Individual elements of `result` may be null. In the caller
-      // (emitBaseIndexForLayout), we assert that all such dimensions are sliced
-      // off.
-      if (!result[k])
-        continue;
-      result[k] = b.add(result[k], CTAOffset[k]);
-    }
-  }
-  return result;
-}
-
-inline SmallVector<Value>
-emitBaseIndexForLayout(Location loc, RewriterBase &rewriter,
-                       const TargetInfoBase &target, Attribute layout,
-                       RankedTensorType type, bool withCTAOffset) {
-  SmallVector<Value> idx = ::intel::emitBaseIndexForLayoutImpl(
-      loc, rewriter, target, layout, type, withCTAOffset);
-
-  // Check that any null values were sliced out.
-  for (Value v : idx) {
-    if (!v) {
-      llvm::errs() << "Failed to generate indexing code, possibly due to bad "
-                      "#mma layout. Please rerun your program with "
-                      "MLIR_ENABLE_DUMP=1 and file a bug."
-                   << "\nloc: " << loc << "\nlayout: " << layout
-                   << "\ntype: " << type << "\nwithCTAOffset: " << withCTAOffset
-                   << "\n";
-      llvm::report_fatal_error("Failed to generate indexing code");
-    }
-  }
-
-  return idx;
-}
-
-inline SmallVector<SmallVector<unsigned>>
-emitOffsetForLayout(Attribute layout, RankedTensorType type) {
-  return mlir::emitOffsetForLayout(layout, type);
-}
-
-// Emit indices calculation within each ConversionPattern, and returns a
-// [elemsPerThread X rank] index matrix.
-inline SmallVector<SmallVector<Value>>
-emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
-            Attribute layout, RankedTensorType type, bool withCTAOffset) {
-  auto b = TritonLLVMOpBuilder(loc, rewriter);
-  MLIRContext *ctx = rewriter.getContext();
-  auto shape = type.getShape();
-  std::optional<LinearLayout> ll = triton::gpu::toLinearLayout(shape, layout);
-  if (ll.has_value())
-    return mlir::emitIndices(loc, rewriter, target, layout, type,
-                             withCTAOffset);
-
-  // step 1, delinearize threadId to get the base index
-  auto multiDimBase = ::intel::emitBaseIndexForLayout(
-      loc, rewriter, target, layout, type, withCTAOffset);
-  // step 2, get offset of each element
-  auto offset = intel::emitOffsetForLayout(layout, type);
-  // step 3, add offset to base, and reorder the sequence
-  //         of indices to guarantee that elems in the same
-  //         sizePerThread are adjacent in order
-  unsigned rank = shape.size();
-  unsigned elemsPerThread = offset.size();
-  SmallVector<SmallVector<Value>> multiDimIdx(elemsPerThread,
-                                              SmallVector<Value>(rank));
-  for (unsigned n = 0; n < elemsPerThread; ++n)
-    for (unsigned k = 0; k < rank; ++k)
-      multiDimIdx[n][k] = b.add(multiDimBase[k], b.i32_val(offset[n][k]));
-
-  return multiDimIdx;
-}
-
 Value convertBf16ToFp32(Location loc, ConversionPatternRewriter &rewriter,
                         Value v);
 Value convertFp32ToBf16(Location loc, ConversionPatternRewriter &rewriter,
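
For reference, the deleted intel::emitIndices reduces to a base-plus-offset
expansion (its own "step 1/2/3" comments): each thread computes one
multi-dimensional base index, then each of its elemsPerThread elements is
addressed by adding a layout-derived constant offset to that base, yielding an
[elemsPerThread x rank] index matrix. The sketch below illustrates that
expansion in plain C++ over integers rather than emitted IR values;
expandIndices and every name in it are hypothetical illustrations, not part of
the Triton code base.

#include <cstdint>
#include <cstdio>
#include <vector>

using Index = std::vector<int64_t>; // one multi-dimensional index; size == rank

// Mirrors the removed helper's final loop: add each per-element offset to the
// per-thread base index, producing an [elemsPerThread x rank] matrix.
// Hypothetical sketch, not the deleted function itself.
std::vector<Index>
expandIndices(const Index &multiDimBase,
              const std::vector<std::vector<unsigned>> &offset) {
  std::vector<Index> multiDimIdx(offset.size(), Index(multiDimBase.size()));
  for (size_t n = 0; n < offset.size(); ++n)         // each element
    for (size_t k = 0; k < multiDimBase.size(); ++k) // each dimension
      multiDimIdx[n][k] = multiDimBase[k] + offset[n][k];
  return multiDimIdx;
}

int main() {
  Index base = {8, 0}; // e.g. one thread's base index into a 2-D tensor
  std::vector<std::vector<unsigned>> offset = {{0, 0}, {0, 1}, {1, 0}, {1, 1}};
  for (const Index &idx : expandIndices(base, offset))
    std::printf("(%lld, %lld)\n", (long long)idx[0], (long long)idx[1]);
}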