@@ -272,13 +272,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
                               const LinearLayout &dstLayout,
                               OpAdaptor adaptor,
                               ConversionPatternRewriter &rewriter) const {
-    MLIRContext *ctx = op.getContext();
-    auto loc = op.getLoc();
-    auto b = TritonLLVMOpBuilder(loc, rewriter);
-    auto srcTy = op.getSrc().getType();
-    auto dstTy = op.getType();
-
-    assert(cvtNeedsSharedMemory(srcTy, dstTy));
+    assert(cvtNeedsSharedMemory(op.getSrc().getType(), op.getType()));
 
     // Try to use swizzling to implement the conversion
     // HACK Remove once AMD tests pass for the swizzling path
@@ -287,52 +281,9 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
       return success();
     }
 
-    SmallVector<Value> inVals =
-        unpackLLElements(loc, adaptor.getSrc(), rewriter);
-    assert(!inVals.empty());
-
-    // We munge the input values by converting i<n> (n<8) elements to i8 and
-    // pointers to i64. This is necessary because TargetInfo::loadDShared and
-    // storeDShared can't handle vectors of pointers or sub-byte elements.
-    auto elemTy = srcTy.getElementType();
-    auto isSubByteInt =
-        elemTy.isInteger() && elemTy.getIntOrFloatBitWidth() < 8;
-    auto isPtr = isa<triton::PointerType>(elemTy);
-    auto llvmElemTyOrig = getTypeConverter()->convertType(elemTy);
-    if (isSubByteInt)
-      elemTy = IntegerType::get(elemTy.getContext(), 8);
-    else if (isPtr)
-      elemTy = IntegerType::get(elemTy.getContext(), 64);
-    auto llvmElemTy = getTypeConverter()->convertType(elemTy);
-
-    // Munge input values
-    for (const auto &it : llvm::enumerate(inVals)) {
-      if (isSubByteInt) {
-        inVals[it.index()] = b.zext(llvmElemTy, it.value());
-      } else if (isPtr) {
-        inVals[it.index()] = b.ptrtoint(llvmElemTy, it.value());
-      }
-    }
-
-    // Pretty sure this is the identity function ATM
-    // It'd be better to simply call `quotient({kBlock})` and
-    // remove kBlock from transferWithinBlockImpl
-    auto srcLayoutWithinBlock = getLayoutWithinBlock(srcLayout);
-    auto dstLayoutWithinBlock = getLayoutWithinBlock(dstLayout);
-    SmallVector<Value> outVals = transferWithinBlockImpl(
-        inVals, op, srcLayoutWithinBlock, dstLayoutWithinBlock, rewriter);
-
-    // Unmunge output values
-    for (const auto &it : llvm::enumerate(outVals)) {
-      if (isSubByteInt) {
-        outVals[it.index()] = b.trunc(llvmElemTyOrig, it.value());
-      } else if (isPtr) {
-        outVals[it.index()] = b.inttoptr(llvmElemTyOrig, it.value());
-      }
-    }
+    Value result = transferWithinBlockPadding(op, adaptor.getSrc(), targetInfo,
+                                              getTypeConverter(), rewriter);
 
-    Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter,
-                                  op.getType());
     rewriter.replaceOp(op, result);
     return success();
   }
@@ -343,223 +294,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
                                    DecomposedWarpConversion decomposed,
                                    OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const;
-
-  SmallVector<Value>
-  transferWithinBlockImpl(ArrayRef<Value> inVals, ConvertLayoutOp op,
-                          const LinearLayout &srcLayout,
-                          const LinearLayout &dstLayout,
-                          ConversionPatternRewriter &rewriter) const {
-    MLIRContext *ctx = op.getContext();
-    auto loc = op.getLoc();
-    auto b = TritonLLVMOpBuilder(loc, rewriter);
-
-    RankedTensorType srcTy = op.getSrc().getType();
-    auto srcElemTy = srcTy.getElementType();
-    const bool isInt1 = srcElemTy.isInteger(1);
-
-    StringAttr kRegister = str_attr("register");
-    StringAttr kLane = str_attr("lane");
-    StringAttr kWarp = str_attr("warp");
-    StringAttr kBlock = str_attr("block");
-    StringAttr kOffset = str_attr("offset");
-    StringAttr kIteration = str_attr("iteration");
-
-    auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc);
-
-    auto scratchConfig =
-        getScratchConfigForCvt(op.getSrc().getType(), op.getType());
-    auto tensorShapePerCTA = convertType<unsigned, int64_t>(getShapePerCTA(
-        op.getSrc().getType().getEncoding(), op.getType().getShape()));
-    // Input dims: [offset, iteration, block]
-    // Output dims: dimN-1, dimN-2, ..., dim0, where N is obtained from repShape
-    LinearLayout sharedLayout = chooseShemLayoutForRegToRegConversion(
-        ctx, tensorShapePerCTA, scratchConfig.repShape, scratchConfig.order);
-
-    // Layout for the store from registers to shared memory.
-    //
-    // Note: If two threads in the same warp write to the same shmem offset, the
-    // hardware resolves that without a stall or a bank conflict. Therefore we
-    // don't need to avoid duplicate writes.
-    // Input dims: [reg, lane, warp]
-    // Output dims: [offset, iteration]
-    bool isStMatrix = targetInfo.canUseStMatrix(
-        op.getSrc().getType(), scratchConfig.repShape,
-        scratchConfig.paddedRepShape, scratchConfig.order,
-        /*swizzleByteSize=*/0);
-    LinearLayout shmemStoreLayout =
-        isStMatrix ? chooseStMatrixLayout(ctx, op.getSrc().getType(),
-                                          /*swizzleByteSize=*/0)
-                   : srcLayout.invertAndCompose(sharedLayout);
-
-    const int shmemAllocatedNumElems =
-        getNumScratchElements(scratchConfig.paddedRepShape);
-    assert(shmemStoreLayout.getOutDimSize(kOffset) <= shmemAllocatedNumElems);
-
-    // Layout for the load from shmem to registers.
-    LinearLayout shmemLoadLayout = dstLayout.invertAndCompose(sharedLayout);
-
-    // Check that the `register` fully determines the `iteration`. That is,
-    // each thread does exactly the same reads and writes to shmem on each
-    // iteration, just with different input/output registers.
-    assert(
-        shmemStoreLayout.sublayoutIsZero({kLane, kWarp, kBlock}, {kIteration}));
-    assert(
-        shmemLoadLayout.sublayoutIsZero({kLane, kWarp, kBlock}, {kIteration}));
-
-    // iteration -> registers
-    SmallVector<SmallVector<int>> inRegsForIter =
-        collectRegsForIter(ctx, shmemStoreLayout);
-    SmallVector<SmallVector<int>> outRegsForIter =
-        collectRegsForIter(ctx, shmemLoadLayout);
-
-    Value smemBase =
-        LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation());
-    auto sharedPtrTy = smemBase.getType();
-    Type elemTy = inVals[0].getType();
-    auto outSize = shmemLoadLayout.getInDimSize(kRegister);
-    auto iterations = sharedLayout.getInDimSize(kIteration);
-    assert(scratchConfig.inVec * iterations <= inVals.size());
-    assert(scratchConfig.outVec * iterations <= outSize);
-
-    // Check only one dimension has been padded.
-    // This means the difference between the padded shape and the original shape
-    // should only be in one dimension, specifically in
-    // `scratchConfig.order[0]`.
-    auto rank = scratchConfig.repShape.size();
-    for (auto i = 0; i < rank; i++) {
-      if (i == scratchConfig.order[0]) {
-        continue;
-      }
-      assert(scratchConfig.repShape[i] == scratchConfig.paddedRepShape[i]);
-    }
-    auto paddedStride = scratchConfig.repShape[scratchConfig.order[0]];
-    auto paddedSize =
-        scratchConfig.paddedRepShape[scratchConfig.order[0]] - paddedStride;
-
-    // Linear layout function is split in two parts below:
-    //
-    // L(r, t, w, b) = L(0, t, w, b) xor L(r, 0, 0, 0)
-    // offset = regBase xor regIdx
-    //
-    // It is the same hack as what we've done in the emitIndices function to get
-    // around performance issues on AMD GPUs
-    auto getVecAddr = [&](LinearLayout &layout, Value &regBase,
-                          int regSlice) -> Value {
-      auto regIdx = layout
-                        .apply({{kRegister, regSlice},
-                                {kLane, 0},
-                                {kWarp, 0},
-                                {kBlock, 0}})[0]
-                        .second;
-      Value offset = b.xor_(regBase, b.i32_val(regIdx));
-      if (paddedSize > 0) {
-        assert(llvm::isPowerOf2_32(paddedStride));
-        assert(llvm::isPowerOf2_32(paddedSize));
-        auto rshiftVal = llvm::Log2_32(paddedStride);
-        auto lshiftVal = llvm::Log2_32(paddedSize);
-        offset = b.add(
-            b.shl(b.lshr(offset, b.i32_val(rshiftVal)), b.i32_val(lshiftVal)),
-            offset);
-      }
-      auto vecAddr = b.gep(sharedPtrTy, elemTy, smemBase, offset,
-                           LLVM::GEPNoWrapFlags::inbounds);
-      return vecAddr;
-    };
-
-    auto storeBase = applyLinearLayout(loc, rewriter, shmemStoreLayout,
-                                       {{kRegister, b.i32_val(0)},
-                                        {kLane, laneId},
-                                        {kWarp, warpId},
-                                        {kBlock, b.i32_val(0)}})[0]
-                         .second;
-    auto loadBase = applyLinearLayout(loc, rewriter, shmemLoadLayout,
-                                      {{kRegister, b.i32_val(0)},
-                                       {kLane, laneId},
-                                       {kWarp, warpId},
-                                       {kBlock, b.i32_val(0)}})[0]
-                        .second;
-    // register idx -> Value
-    llvm::MapVector<int, Value> outVals;
-    for (int i = 0; i < iterations; i++) {
-      if (i != 0)
-        b.barrier();
-
-      auto &inRegs = inRegsForIter[i];
-      auto &outRegs = outRegsForIter[i];
-
-      // When using `stmatrix`, we can store `inVec` elements even if they are
-      // not contiguous
-      auto inVec = isStMatrix ? shmemStoreLayout.getNumConsecutiveInOut()
-                              : scratchConfig.inVec;
-      for (int j = 0; j < inVals.size() / iterations; j += inVec) {
-        auto inRegSlice = inRegs[j];
-        Value vecAddr = getVecAddr(shmemStoreLayout, storeBase, inRegSlice);
-        SmallVector<Value> inValsVec;
-        for (int k = 0; k < inVec; k++)
-          inValsVec.push_back(inVals[inRegSlice + k]);
-        Value valsVec = packLLVector(loc, inValsVec, rewriter);
-        if (isStMatrix) {
-          targetInfo.storeMatrixShared(rewriter, loc, vecAddr, valsVec);
-        } else {
-          targetInfo.storeDShared(rewriter, loc, vecAddr, std::nullopt, valsVec,
-                                  /*pred=*/b.true_val());
-        }
-      }
-
-      b.barrier();
-
-      for (int j = 0; j < outSize / iterations; j += scratchConfig.outVec) {
-        auto outRegSlice = outRegs[j];
-        auto vecAddr = getVecAddr(shmemLoadLayout, loadBase, outRegSlice);
-        Value valsVec =
-            targetInfo.loadDShared(rewriter, loc, vecAddr, std::nullopt,
-                                   vec_ty(elemTy, scratchConfig.outVec),
-                                   /*pred=*/b.true_val());
-        for (Value v : unpackLLVector(loc, valsVec, rewriter)) {
-          if (isInt1) {
-            // TODO(Intel): special handling for the boolean case required. Does
-            // this prevent a later optimization that we can't handle, or is
-            // there something about the layout/SLM loads and stores that
-            // requires special "transcribing" the boolean to the result of the
-            // cmp?
-            outVals[outRegSlice++] =
-                b.icmp_ne(v, rewriter.create<LLVM::ConstantOp>(
-                                 loc, i8_ty, rewriter.getI8IntegerAttr(0)));
-          } else {
-            outVals[outRegSlice++] = v;
-          }
-        }
-      }
-    }
-
-    SmallVector<Value> outValsVec;
-    for (size_t i = 0; i < outVals.size(); i++)
-      outValsVec.push_back(outVals[i]);
-    return outValsVec;
-  }
-
-  // Determine which registers are read/written in which iteration of the shmem
-  // transfer specified by `layout`.
-  SmallVector<SmallVector<int> /*registers*/>
-  collectRegsForIter(MLIRContext *ctx, const LinearLayout &layout) const {
-    StringAttr kRegister = str_attr("register");
-    StringAttr kLane = str_attr("lane");
-    StringAttr kWarp = str_attr("warp");
-    StringAttr kBlock = str_attr("block");
-    StringAttr kIteration = str_attr("iteration");
-
-    // The choice of iteration should be determined only by the register. That
-    // is, it should be correct to split the register dimension into iterations.
-    assert(layout.sublayoutIsZero({kLane, kWarp, kBlock}, {kIteration}));
-
-    LinearLayout sublayout = layout.sublayout({kRegister}, {kIteration});
-    SmallVector<SmallVector<int>> ret(sublayout.getOutDimSize(kIteration));
-    for (int reg = 0; reg < sublayout.getInDimSize(kRegister); reg++) {
-      auto idx = sublayout.apply({{kRegister, reg}});
-      ret[idx.begin()->second].push_back(reg);
-    }
-    return ret;
-  }
 };
 
 } // namespace
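
Note (not part of the patch): the padded-offset arithmetic performed by the removed `getVecAddr` lambda, and now hidden behind `transferWithinBlockPadding`, is just a shift-and-add on the logical offset. A minimal standalone sketch, assuming hypothetical values for the stride and pad amount (only the formula `offset + ((offset >> log2(stride)) << log2(pad))` mirrors the removed code):

```cpp
// Illustrative only: standalone sketch of the padded shared-memory offset
// computation; not Triton code. Names and constants are hypothetical.
#include <cassert>
#include <cstdint>
#include <cstdio>

// Insert `padAmount` unused elements after every `stride` elements (both
// powers of two), mirroring:
//   offset += (offset >> log2(stride)) << log2(padAmount);
static uint32_t padOffset(uint32_t offset, uint32_t stride, uint32_t padAmount) {
  assert((stride & (stride - 1)) == 0 && (padAmount & (padAmount - 1)) == 0);
  uint32_t row = offset / stride; // lshr by log2(stride)
  return offset + row * padAmount; // shl by log2(padAmount), then add
}

int main() {
  // With stride = 32 and padAmount = 4, logical element 32 lands at padded
  // offset 36, skipping the 4 pad slots that break up bank conflicts.
  printf("%u\n", padOffset(32, 32, 4)); // prints 36
  printf("%u\n", padOffset(65, 32, 4)); // prints 73
  return 0;
}
```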
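Note (not part of the patch): the removed code splits the layout application as `offset = regBase xor regIdx`, which is valid because a `LinearLayout` is linear over GF(2): each output bit is a XOR of input bits, so the lane/warp-dependent base can be computed once and each register contributes a constant XOR term. A toy sketch with a hypothetical XOR-linear layout:

```cpp
// Illustrative only: why L(r, t, w) == L(0, t, w) ^ L(r, 0, 0) holds for an
// XOR-linear map, justifying the "base ^ regIdx" split in the removed code.
#include <cstdint>
#include <cstdio>

// Hypothetical toy layout: XOR-combine bit-shifted inputs (any map where every
// output bit is a XOR of input bits behaves the same way).
static uint32_t applyLayout(uint32_t reg, uint32_t lane, uint32_t warp) {
  return (reg << 0) ^ (lane << 2) ^ (warp << 7);
}

int main() {
  bool ok = true;
  for (uint32_t reg = 0; reg < 8; ++reg)
    for (uint32_t lane = 0; lane < 8; ++lane) {
      uint32_t full = applyLayout(reg, lane, /*warp=*/1);
      uint32_t split = applyLayout(0, lane, 1) ^ applyLayout(reg, 0, 0);
      ok = ok && (full == split);
    }
  printf(ok ? "base ^ regIdx matches the full apply\n" : "mismatch\n");
  return 0;
}
```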
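Note (not part of the patch): the removed `collectRegsForIter` buckets register indices by the iteration they map to, so each barrier-separated round of the shared-memory transfer touches a fixed set of registers. A minimal sketch, with a hypothetical contiguous register-to-iteration mapping standing in for the real `LinearLayout` query:

```cpp
// Illustrative only: grouping registers by iteration, as the removed
// collectRegsForIter does. The mapping function below is hypothetical.
#include <cstdio>
#include <vector>

// Stand-in for sublayout.apply({{kRegister, reg}}): here the iteration is
// simply the high part of the register index.
static int iterationForRegister(int reg, int regsPerIteration) {
  return reg / regsPerIteration;
}

static std::vector<std::vector<int>> collectRegsForIter(int numRegs,
                                                        int numIterations) {
  std::vector<std::vector<int>> regsForIter(numIterations);
  for (int reg = 0; reg < numRegs; ++reg)
    regsForIter[iterationForRegister(reg, numRegs / numIterations)]
        .push_back(reg);
  return regsForIter;
}

int main() {
  // 8 registers over 2 iterations: {0,1,2,3} in round 0, {4,5,6,7} in round 1.
  auto buckets = collectRegsForIter(/*numRegs=*/8, /*numIterations=*/2);
  for (size_t it = 0; it < buckets.size(); ++it) {
    printf("iteration %zu:", it);
    for (int reg : buckets[it])
      printf(" %d", reg);
    printf("\n");
  }
  return 0;
}
```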