@@ -150,6 +150,14 @@ translateStoreXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
150150 }
151151}
152152
153+ //
154+ // Note:
155+ // Block operations for tile of sub byte element types are handled by
156+ // emulating with larger element types.
157+ // Tensor descriptor are keep intact and only ops consuming them are
158+ // emulated
159+ //
160+
153161class CreateNdDescToXeVMPattern
154162 : public OpConversionPattern<xegpu::CreateNdDescOp> {
155163 using OpConversionPattern::OpConversionPattern;
@@ -262,9 +270,57 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
262270 op, " Expected offset rank to match descriptor rank." );
263271 auto elemType = tdescTy.getElementType ();
264272 auto elemBitSize = elemType.getIntOrFloatBitWidth ();
265- if (elemBitSize % 8 != 0 )
273+ bool isSubByte = elemBitSize < 8 ;
274+ uint64_t wScaleFactor = 1 ;
275+
276+ if (!isSubByte && (elemBitSize % 8 != 0 ))
266277 return rewriter.notifyMatchFailure (
267278 op, " Expected element type bit width to be multiple of 8." );
279+ auto tileW = tdescTy.getDimSize (tileRank - 1 );
280+ // For sub byte types, only 4bits are currently supported.
281+ if (isSubByte) {
282+ if (elemBitSize != 4 )
283+ return rewriter.notifyMatchFailure (
284+ op, " Only sub byte types of 4bits are supported." );
285+ if (tileRank != 2 )
286+ return rewriter.notifyMatchFailure (
287+ op, " Sub byte types are only supported for 2D tensor descriptors." );
288+ auto subByteFactor = 8 / elemBitSize;
289+ auto tileH = tdescTy.getDimSize (0 );
290+ // Handle special case for packed load.
291+ if constexpr (std::is_same_v<OpType, xegpu::LoadNdOp>) {
292+ if (op.getPacked ().value_or (false )) {
293+ // packed load is implemented as packed loads of 8bit elements.
294+ if (tileH == systolicDepth * 4 &&
295+ tileW == executionSize * subByteFactor) {
296+ // Usage case for loading as Matrix B with pack request.
297+ // source is assumed to pre-packed into 8bit elements
298+ // Emulate with 8bit loads with pack request.
299+ // scaled_tileW = executionSize
300+ elemType = rewriter.getIntegerType (8 );
301+ tileW = executionSize;
302+ wScaleFactor = subByteFactor;
303+ }
304+ }
305+ }
306+ // If not handled by packed load case above, handle other cases.
307+ if (wScaleFactor == 1 ) {
308+ auto sub16BitFactor = subByteFactor * 2 ;
309+ if (tileW == executionSize * sub16BitFactor) {
310+ // Usage case for loading as Matrix A operand
311+ // Emulate with 16bit loads/stores.
312+ // scaled_tileW = executionSize
313+ elemType = rewriter.getIntegerType (16 );
314+ tileW = executionSize;
315+ wScaleFactor = sub16BitFactor;
316+ } else {
317+ return rewriter.notifyMatchFailure (
318+ op, " Unsupported tile shape for sub byte types." );
319+ }
320+ }
321+ // recompute element bit size for emulation.
322+ elemBitSize = elemType.getIntOrFloatBitWidth ();
323+ }
268324
269325 // Get address space from tensor descriptor memory space.
270326 auto ptrTypeLLVM = LLVM::LLVMPointerType::get (
@@ -298,15 +354,27 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
298354 // Convert base pointer (i64) to LLVM pointer type.
299355 Value basePtrLLVM =
300356 LLVM::IntToPtrOp::create (rewriter, loc, ptrTypeLLVM, basePtr);
357+ // FIXME: width or pitch is not the same as baseShapeW it should be the
358+ // stride of the second to last dimension in row major layout.
301359 // Compute width in bytes.
302- Value baseWidthByte =
360+ Value baseShapeWInBytes =
303361 arith::MulIOp::create (rewriter, loc, baseShapeW, elemByteSize);
304362 // Compute pitch in bytes.
305- Value basePitchByte =
363+ Value basePitchBytes =
306364 arith::MulIOp::create (rewriter, loc, basePitch, elemByteSize);
307365
308- // Get tile width from the tensor descriptor type.
309- auto tileW = tdescTy.getDimSize (tileRank - 1 );
366+ if (wScaleFactor > 1 ) {
367+ // Scale offsetW, baseShapeWInBytes for sub byte emulation.
368+ // Note: tileW is already scaled above.
369+ Value wScaleFactorValLog2 = arith::ConstantIntOp::create (
370+ rewriter, loc, rewriter.getI32Type (), llvm::Log2_64 (wScaleFactor));
371+ baseShapeWInBytes = arith::ShRSIOp::create (
372+ rewriter, loc, baseShapeWInBytes, wScaleFactorValLog2);
373+ basePitchBytes = arith::ShRSIOp::create (rewriter, loc, basePitchBytes,
374+ wScaleFactorValLog2);
375+ offsetW =
376+ arith::ShRSIOp::create (rewriter, loc, offsetW, wScaleFactorValLog2);
377+ }
310378 // Get tile height from the tensor descriptor type.
311379 auto tileH = tdescTy.getDimSize (0 );
312380 // Get vblocks from the tensor descriptor type.
@@ -330,17 +398,17 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
330398 auto storeCacheControl =
331399 translateStoreXeGPUCacheHint (op.getL1Hint (), op.getL3Hint ());
332400 xevm::BlockStore2dOp::create (
333- rewriter, loc, basePtrLLVM, baseWidthByte , baseShapeH,
334- basePitchByte , offsetW, offsetH, elemBitSize, tileW, tileH, src,
401+ rewriter, loc, basePtrLLVM, baseShapeWInBytes , baseShapeH,
402+ basePitchBytes , offsetW, offsetH, elemBitSize, tileW, tileH, src,
335403 xevm::StoreCacheControlAttr::get (ctxt, storeCacheControl));
336404 rewriter.eraseOp (op);
337405 } else {
338406 auto loadCacheControl =
339407 translateLoadXeGPUCacheHint (op.getL1Hint (), op.getL3Hint ());
340408 if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) {
341409 xevm::BlockPrefetch2dOp::create (
342- rewriter, loc, basePtrLLVM, baseWidthByte , baseShapeH,
343- basePitchByte , offsetW, offsetH, elemBitSize, tileW, tileH,
410+ rewriter, loc, basePtrLLVM, baseShapeWInBytes , baseShapeH,
411+ basePitchBytes , offsetW, offsetH, elemBitSize, tileW, tileH,
344412 vblocks, xevm::LoadCacheControlAttr::get (ctxt, loadCacheControl));
345413 rewriter.eraseOp (op);
346414 } else {
@@ -354,9 +422,9 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
354422 : rewriter.getIntegerType (elemBitSize));
355423
356424 Value resultFlatVec = xevm::BlockLoad2dOp::create (
357- rewriter, loc, loadedTy, basePtrLLVM, baseWidthByte, baseShapeH ,
358- basePitchByte, offsetW, offsetH, elemBitSize, tileW, tileH ,
359- vblocks, transpose, vnni,
425+ rewriter, loc, loadedTy, basePtrLLVM, baseShapeWInBytes ,
426+ baseShapeH, basePitchBytes, offsetW, offsetH, elemBitSize, tileW,
427+ tileH, vblocks, transpose, vnni,
360428 xevm::LoadCacheControlAttr::get (ctxt, loadCacheControl));
361429 resultFlatVec = vector::BitCastOp::create (
362430 rewriter, loc,
0 commit comments