@@ -1644,7 +1644,7 @@ int32_t getScaleSel(int32_t blockSize, unsigned bitWidth, int32_t scaleWaveHalf,
16441644 // those values are merged together. (Note: scaleWaveHalf isn't a high-level
16451645 // attribute but is derived from firstScaleLane).
16461646 assert (llvm::is_contained ({16 , 32 }, blockSize));
1647- assert (llvm::is_contained (llvm::ArrayRef< unsigned >{ 4 , 6 , 8 }, bitWidth));
1647+ assert (llvm::is_contained ({ 4u , 6u , 8u }, bitWidth));
16481648
16491649 const bool isFp8 = bitWidth == 8 ;
16501650 const bool isBlock16 = blockSize == 16 ;
@@ -2276,72 +2276,106 @@ struct AMDGPUPermlaneLowering : public ConvertOpToLLVMPattern<PermlaneSwapOp> {
22762276 }
22772277};
22782278
2279- struct AMDGPUMakeDmaBaseLowering
2280- : public ConvertOpToLLVMPattern<MakeDmaBaseOp> {
2281- using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
2279+ static Value setValueAtOffset (ConversionPatternRewriter &rewriter, Location loc,
2280+ Value accumulator, Value value, int64_t shift) {
2281+ shift = shift % 32 ;
2282+ Value shiftAmount;
2283+ if (shift != 0 ) {
2284+ shiftAmount = createI32Constant (rewriter, loc, shift % 32 );
2285+ value = LLVM::ShlOp::create (rewriter, loc, value, shiftAmount);
2286+ }
2287+
2288+ if (matchPattern (accumulator, mlir::m_Zero ()))
2289+ return value;
2290+
2291+ constexpr bool isDisjoint = true ;
2292+ return LLVM::OrOp::create (rewriter, loc, accumulator, value, isDisjoint);
2293+ }
2294+
2295+ template <typename BaseOp>
2296+ struct AMDGPUMakeDmaBaseLowering : public ConvertOpToLLVMPattern <BaseOp> {
2297+ using ConvertOpToLLVMPattern<BaseOp>::ConvertOpToLLVMPattern;
2298+ using Adaptor = typename ConvertOpToLLVMPattern<BaseOp>::OpAdaptor;
22822299
22832300 AMDGPUMakeDmaBaseLowering (const LLVMTypeConverter &converter, Chipset chipset)
2284- : ConvertOpToLLVMPattern<MakeDmaBaseOp >(converter), chipset(chipset) {}
2301+ : ConvertOpToLLVMPattern<BaseOp >(converter), chipset(chipset) {}
22852302 Chipset chipset;
22862303
22872304 LogicalResult
2288- matchAndRewrite (MakeDmaBaseOp op, OpAdaptor adaptor,
2305+ matchAndRewrite (BaseOp op, Adaptor adaptor,
22892306 ConversionPatternRewriter &rewriter) const override {
22902307 if (chipset < kGfx1250 )
22912308 return op->emitOpError (" make_dma_base is only supported on gfx1250" );
22922309
22932310 Location loc = op.getLoc ();
22942311
2312+ constexpr int32_t constlen = 4 ;
2313+ Value consts[constlen];
2314+ for (int64_t i = 0 ; i < constlen; i++)
2315+ consts[i] = createI32Constant (rewriter, loc, i);
2316+
2317+ constexpr int32_t sgprslen = constlen;
2318+ Value sgprs[sgprslen];
2319+ for (int64_t i = 0 ; i < sgprslen; i++) {
2320+ sgprs[i] = consts[0 ];
2321+ }
2322+
2323+ sgprs[0 ] = consts[1 ];
2324+
2325+ if (op.isGather ()) {
2326+ sgprs[0 ] = setValueAtOffset (rewriter, loc, sgprs[0 ], consts[1 ], 30 );
2327+
2328+ auto type = cast<TDMGatherBaseType>(op.getResult ().getType ());
2329+ Type indexType = type.getIndexType ();
2330+ unsigned indexSize = indexType.getIntOrFloatBitWidth ();
2331+ assert (llvm::is_contained ({16u , 32u }, indexSize) &&
2332+ " expected index_size to be 16 or 32" );
2333+ unsigned idx = (indexSize / 16 ) - 1 ;
2334+
2335+ if (idx)
2336+ sgprs[0 ] = setValueAtOffset (rewriter, loc, sgprs[0 ], consts[1 ], 31 );
2337+ }
2338+
22952339 ValueRange ldsIndices = adaptor.getLdsIndices ();
22962340 Value lds = adaptor.getLds ();
22972341 auto ldsMemRefType = cast<MemRefType>(op.getLds ().getType ());
22982342
2299- Value ldsPtr =
2300- getStridedElementPtr ( rewriter, loc, ldsMemRefType, lds, ldsIndices);
2343+ Value ldsPtr = ConvertToLLVMPattern::getStridedElementPtr (
2344+ rewriter, loc, ldsMemRefType, lds, ldsIndices);
23012345
23022346 ValueRange globalIndices = adaptor.getGlobalIndices ();
23032347 Value global = adaptor.getGlobal ();
23042348 auto globalMemRefType = cast<MemRefType>(op.getGlobal ().getType ());
23052349
2306- Value globalPtr = getStridedElementPtr (rewriter, loc, globalMemRefType,
2307- global, globalIndices);
2350+ Value globalPtr = ConvertToLLVMPattern:: getStridedElementPtr (
2351+ rewriter, loc, globalMemRefType, global, globalIndices);
23082352
23092353 Type i32 = rewriter.getI32Type ();
23102354 Type i64 = rewriter.getI64Type ();
23112355
2312- Value castForLdsAddr = LLVM::PtrToIntOp::create (rewriter, loc, i32 , ldsPtr);
2356+ sgprs[ 1 ] = LLVM::PtrToIntOp::create (rewriter, loc, i32 , ldsPtr);
23132357 Value castForGlobalAddr =
23142358 LLVM::PtrToIntOp::create (rewriter, loc, i64 , globalPtr);
23152359
2316- Value lowHalf =
2317- LLVM::TruncOp::create (rewriter, loc, i32 , castForGlobalAddr);
2360+ sgprs[2 ] = LLVM::TruncOp::create (rewriter, loc, i32 , castForGlobalAddr);
23182361
23192362 Value shift = LLVM::LShrOp::create (rewriter, loc, castForGlobalAddr,
23202363 createI64Constant (rewriter, loc, 32 ));
23212364
23222365 Value highHalf = LLVM::TruncOp::create (rewriter, loc, i32 , shift);
23232366
23242367 Value mask = createI32Constant (rewriter, loc, (1ull << 25 ) - 1 );
2325- Value validHighHalf = LLVM::AndOp::create (rewriter, loc, highHalf, mask);
2368+ highHalf = LLVM::AndOp::create (rewriter, loc, highHalf, mask);
23262369
2327- Value typeField = createI32Constant (rewriter, loc, 2 << 30 );
2328- Value highHalfPlusType =
2329- LLVM::OrOp::create (rewriter, loc, validHighHalf, typeField);
2330-
2331- Value c0 = createI32Constant (rewriter, loc, 0 );
2332- Value c1 = createI32Constant (rewriter, loc, 1 );
2333- Value c2 = createI32Constant (rewriter, loc, 2 );
2334- Value c3 = createI32Constant (rewriter, loc, 3 );
2370+ sgprs[3 ] = setValueAtOffset (rewriter, loc, highHalf, consts[2 ], 30 );
23352371
23362372 Type v4i32 = this ->typeConverter ->convertType (VectorType::get (4 , i32 ));
23372373 assert (v4i32 && " expected type conversion to succeed" );
23382374 Value result = LLVM::PoisonOp::create (rewriter, loc, v4i32);
2339- result = LLVM::InsertElementOp::create (rewriter, loc, result, c1, c0);
2340- result = LLVM::InsertElementOp::create (rewriter, loc, result,
2341- castForLdsAddr, c1);
2342- result = LLVM::InsertElementOp::create (rewriter, loc, result, lowHalf, c2);
2343- result = LLVM::InsertElementOp::create (rewriter, loc, result,
2344- highHalfPlusType, c3);
2375+
2376+ for (auto [sgpr, constant] : llvm::zip_equal (sgprs, consts))
2377+ result =
2378+ LLVM::InsertElementOp::create (rewriter, loc, result, sgpr, constant);
23452379
23462380 rewriter.replaceOp (op, result);
23472381 return success ();
@@ -2360,21 +2394,6 @@ struct AMDGPUMakeDmaDescriptorLowering
23602394
23612395 Value getDGroup0 (OpAdaptor adaptor) const { return adaptor.getBase (); }
23622396
2363- Value setValueAtOffset (ConversionPatternRewriter &rewriter, Location loc,
2364- Value accumulator, Value value, int64_t shift) const {
2365- shift = shift % 32 ;
2366- Value shiftAmount;
2367- if (shift != 0 ) {
2368- shiftAmount = createI32Constant (rewriter, loc, shift % 32 );
2369- value = LLVM::ShlOp::create (rewriter, loc, value, shiftAmount);
2370- }
2371-
2372- if (matchPattern (accumulator, mlir::m_Zero ()))
2373- return value;
2374-
2375- return LLVM::OrOp::create (rewriter, loc, accumulator, value);
2376- }
2377-
23782397 Value setWorkgroupMask (MakeDmaDescriptorOp op, OpAdaptor adaptor,
23792398 ConversionPatternRewriter &rewriter, Location loc,
23802399 Value sgpr0) const {
@@ -2393,9 +2412,8 @@ struct AMDGPUMakeDmaDescriptorLowering
23932412 ConversionPatternRewriter &rewriter, Location loc,
23942413 Value sgpr0, ArrayRef<Value> consts) const {
23952414 unsigned elementTypeWidthInBits = op.getElementTypeWidth ();
2396- assert (
2397- llvm::is_contained<unsigned >({8 , 16 , 32 , 64 }, elementTypeWidthInBits) &&
2398- " expected type width to be 8, 16, 32, or 64." );
2415+ assert (llvm::is_contained ({8u , 16u , 32u , 64u }, elementTypeWidthInBits) &&
2416+ " expected type width to be 8, 16, 32, or 64." );
23992417 int64_t idx = llvm::Log2_32 (elementTypeWidthInBits / 8 );
24002418 Value size = consts[idx];
24012419 return setValueAtOffset (rewriter, loc, sgpr0, size, 16 );
@@ -3055,7 +3073,8 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
30553073 ScaledExtPackedOpLowering, PackedScaledTruncOpLowering,
30563074 PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
30573075 GatherToLDSOpLowering, TransposeLoadOpLowering, AMDGPUPermlaneLowering,
3058- AMDGPUMakeDmaBaseLowering, AMDGPUMakeDmaDescriptorLowering>(converter,
3059- chipset);
3076+ AMDGPUMakeDmaBaseLowering<MakeDmaBaseOp>,
3077+ AMDGPUMakeDmaBaseLowering<MakeGatherDmaBaseOp>,
3078+ AMDGPUMakeDmaDescriptorLowering>(converter, chipset);
30603079 patterns.add <AMDGPUSwizzleBitModeLowering>(converter);
30613080}
0 commit comments