@@ -57,39 +57,6 @@ getSgShapeAndCount(ArrayRef<int64_t> shape, xegpu::LayoutAttr layout) {
5757 return std::make_pair (sgShape, count);
5858}
5959
60- // Calculate offset for each subgroup
61- static SmallVector<OpFoldResult>
62- calculateGlobalOffsets (ConversionPatternRewriter &rewriter, Location loc,
63- const SmallVector<OpFoldResult> &originalOffsets,
64- const SmallVector<Value> &localOffset,
65- const SmallVector<int64_t > &distUnitBaseAddr,
66- const SmallVector<int64_t > &distUnitShape) {
67- assert (localOffset.size () == distUnitBaseAddr.size () &&
68- " localOffset and distUnitBaseAddr must have the same rank" );
69-
70- SmallVector<OpFoldResult> globalOffsets (originalOffsets.begin (),
71- originalOffsets.end ());
72- size_t rank = localOffset.size ();
73- for (size_t i = 0 ; i < rank; ++i) {
74- size_t dimIdx = originalOffsets.size () - rank + i;
75- Value constOffset =
76- rewriter.create <arith::ConstantIndexOp>(loc, distUnitBaseAddr[i]);
77- Value offset =
78- rewriter.createOrFold <index::AddOp>(loc, localOffset[i], constOffset);
79- Value modValue =
80- rewriter.create <arith::ConstantIndexOp>(loc, distUnitShape[i]);
81- Value offsetMod =
82- rewriter.createOrFold <index::RemUOp>(loc, offset, modValue);
83- Value origOffset =
84- getValueOrCreateConstantIndexOp (rewriter, loc, originalOffsets[dimIdx]);
85- Value globalOffset =
86- rewriter.createOrFold <index::AddOp>(loc, origOffset, offsetMod);
87- globalOffsets[dimIdx] = globalOffset;
88- }
89-
90- return globalOffsets;
91- }
92-
9360// / This pattern transforms the CreateNdDescOp to create a subgroup descriptor
9461// / from a workgroup descriptor. It replaces the offsets and sizes with
9562// / appropriate values for the subgroup.
@@ -138,6 +105,39 @@ calculateGlobalOffsets(ConversionPatternRewriter &rewriter, Location loc,
138105struct WgToSgCreateNdOp : public OpConversionPattern <xegpu::CreateNdDescOp> {
139106 using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
140107
108+ // Calculate offset for each subgroup
109+ static SmallVector<OpFoldResult>
110+ calculateGlobalOffsets (ConversionPatternRewriter &rewriter, Location loc,
111+ const SmallVector<OpFoldResult> &originalOffsets,
112+ const SmallVector<Value> &localOffset,
113+ const SmallVector<int64_t > &distUnitBaseAddr,
114+ const SmallVector<int64_t > &distUnitShape) {
115+ assert (localOffset.size () == distUnitBaseAddr.size () &&
116+ " localOffset and distUnitBaseAddr must have the same rank" );
117+
118+ SmallVector<OpFoldResult> globalOffsets (originalOffsets.begin (),
119+ originalOffsets.end ());
120+ size_t rank = localOffset.size ();
121+ for (size_t i = 0 ; i < rank; ++i) {
122+ size_t dimIdx = originalOffsets.size () - rank + i;
123+ Value constOffset =
124+ rewriter.create <arith::ConstantIndexOp>(loc, distUnitBaseAddr[i]);
125+ Value offset =
126+ rewriter.createOrFold <index::AddOp>(loc, localOffset[i], constOffset);
127+ Value modValue =
128+ rewriter.create <arith::ConstantIndexOp>(loc, distUnitShape[i]);
129+ Value offsetMod =
130+ rewriter.createOrFold <index::RemUOp>(loc, offset, modValue);
131+ Value origOffset =
132+ getValueOrCreateConstantIndexOp (rewriter, loc, originalOffsets[dimIdx]);
133+ Value globalOffset =
134+ rewriter.createOrFold <index::AddOp>(loc, origOffset, offsetMod);
135+ globalOffsets[dimIdx] = globalOffset;
136+ }
137+
138+ return globalOffsets;
139+ }
140+
141141 LogicalResult
142142 matchAndRewrite (xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
143143 ConversionPatternRewriter &rewriter) const override {
@@ -390,21 +390,6 @@ struct WgToSgElementwiseOp : public ConversionPattern {
390390 }
391391};
392392
393- // based on the size of the given vector type
394- static TypedValue<MemRefType>
395- allocateSLMBuffer (ConversionPatternRewriter &rewriter, Location loc,
396- VectorType type) {
397- int64_t bits = type.getElementType ().getIntOrFloatBitWidth ();
398- int64_t slmSizeInBytes = type.getNumElements () * bits / 8 ;
399- auto slmTy = MemRefType::get (slmSizeInBytes, rewriter.getI8Type (), {}, 3 );
400- auto slm = rewriter.create <memref::AllocOp>(loc, slmTy);
401- auto viewTy = MemRefType::get (type.getShape (), type.getElementType (), {}, 3 );
402- auto view = rewriter.create <memref::ViewOp>(
403- loc, viewTy, slm, rewriter.create <arith::ConstantIndexOp>(loc, 0 ),
404- ValueRange ());
405- return view;
406- }
407-
408393struct WgToSgConvertLayoutOp
409394 : public OpConversionPattern<xegpu::ConvertLayoutOp> {
410395 using OpConversionPattern<xegpu::ConvertLayoutOp>::OpConversionPattern;
@@ -418,115 +403,29 @@ struct WgToSgConvertLayoutOp
418403 return rewriter.notifyMatchFailure (
419404 op, " Input and target layouts must have subgroup layout" );
420405
421- // initialize values with the source values
422- SmallVector<Value> values (adaptor.getSource ());
423-
424- Location loc = op.getLoc ();
425- MLIRContext *ctx = op.getContext ();
426- VectorType type = op.getResult ().getType ();
427- ArrayRef<int64_t > shape = type.getShape ();
428-
429406 DenseI32ArrayAttr inputSgLayout = input.getSgLayout ();
430407 DenseI32ArrayAttr inputSgData = input.getSgData ();
431408 DenseI32ArrayAttr targetSgLayout = target.getSgLayout ();
432409 DenseI32ArrayAttr targetSgData = target.getSgData ();
433410
434- // we only need SLM support when input and target layouts are different
435- if (inputSgLayout != targetSgLayout || inputSgData != targetSgData) {
436- values.clear ();
437- rewriter.setInsertionPoint (op);
438- TypedValue<MemRefType> slmBuffer = allocateSLMBuffer (rewriter, loc, type);
439-
440- auto linearSgId = rewriter.create <gpu::SubgroupIdOp>(
441- loc, rewriter.getIndexType (), nullptr );
442-
443- { // store to slm buffer
444- SmallVector<int64_t > sgLayout =
445- llvm::to_vector_of<int64_t >(input.getSgLayout ().asArrayRef ());
446- SmallVector<int64_t > sgShape = getSgShapeAndCount (shape, input).first ;
447- auto delinearized = affine::delinearizeIndex (
448- rewriter, loc, linearSgId, getAsIndexOpFoldResult (ctx, sgLayout));
449- if (failed (delinearized))
450- return rewriter.notifyMatchFailure (op, " Failed to delinearize sgId" );
451- SmallVector<Value> sgIds = *delinearized;
452-
453- SmallVector<int64_t > distUnitShape (sgLayout.size ());
454- SmallVector<Value> localOffset (sgLayout.size ());
455- for (size_t i = 0 ; i < sgLayout.size (); i++) {
456- distUnitShape[i] = std::min (sgLayout[i] * sgShape[i], shape[i]);
457- localOffset[i] = rewriter.createOrFold <index::MulOp>(
458- loc, sgIds[i],
459- rewriter.create <arith::ConstantIndexOp>(loc, sgShape[i]));
460- }
461-
462- auto tdescTy = xegpu::TensorDescType::get (
463- sgShape, type.getElementType (), 1 , false , xegpu::MemorySpace::SLM,
464- input.dropSgLayoutAndData ());
465-
466- SmallVector<OpFoldResult> zeros = getAsIndexOpFoldResult (
467- ctx, SmallVector<int64_t >(sgLayout.size (), 0 ));
468- for (auto [data, baseOffsets] :
469- llvm::zip_equal (adaptor.getSource (),
470- StaticTileOffsetRange (shape, distUnitShape))) {
471- SmallVector<OpFoldResult> offsets = calculateGlobalOffsets (
472- rewriter, loc, zeros, localOffset, baseOffsets, distUnitShape);
473- auto tdesc = rewriter.create <xegpu::CreateNdDescOp>(
474- loc, tdescTy, slmBuffer, offsets);
475- rewriter.create <xegpu::StoreNdOp>(loc, data, tdesc, nullptr , nullptr ,
476- nullptr );
477- }
478- }
479-
480- rewriter.create <gpu::BarrierOp>(loc);
481-
482- { // load from SLM
483- SmallVector<int64_t > sgLayout =
484- llvm::to_vector_of<int64_t >(target.getSgLayout ().asArrayRef ());
485- SmallVector<int64_t > sgShape = getSgShapeAndCount (shape, target).first ;
486- auto delinearized = affine::delinearizeIndex (
487- rewriter, loc, linearSgId, getAsIndexOpFoldResult (ctx, sgLayout));
488- if (failed (delinearized))
489- return rewriter.notifyMatchFailure (op, " Failed to delinearize sgId" );
490- SmallVector<Value> sgIds = *delinearized;
491-
492- SmallVector<int64_t > distUnitShape (sgLayout.size ());
493- SmallVector<Value> localOffset (sgLayout.size ());
494- for (size_t i = 0 ; i < sgLayout.size (); i++) {
495- distUnitShape[i] = std::min (sgLayout[i] * sgShape[i], shape[i]);
496- localOffset[i] = rewriter.createOrFold <index::MulOp>(
497- loc, sgIds[i],
498- rewriter.create <arith::ConstantIndexOp>(loc, sgShape[i]));
499- }
500-
501- auto tdescTy = xegpu::TensorDescType::get (
502- sgShape, type.getElementType (), 1 , false , xegpu::MemorySpace::SLM,
503- target.dropSgLayoutAndData ());
504- auto valueTy = VectorType::get (sgShape, type.getElementType ());
505-
506- SmallVector<OpFoldResult> zeros = getAsIndexOpFoldResult (
507- ctx, SmallVector<int64_t >(sgLayout.size (), 0 ));
508- for (auto baseOffsets : StaticTileOffsetRange (shape, distUnitShape)) {
509- SmallVector<OpFoldResult> offsets = calculateGlobalOffsets (
510- rewriter, loc, zeros, localOffset, baseOffsets, distUnitShape);
511- auto tdesc = rewriter.create <xegpu::CreateNdDescOp>(
512- loc, tdescTy, slmBuffer, offsets);
513- auto newOp = rewriter.create <xegpu::LoadNdOp>(
514- loc, TypeRange ({valueTy}), ValueRange ({tdesc}));
515- values.push_back (newOp);
516- }
517- }
518- }
411+ // TODO: currently we only support for optimal case, where input and
412+ // output has the same sg_layout and sg_data, so SLM is not involved.
413+ if (inputSgLayout != targetSgLayout || inputSgData != targetSgData)
414+ return failure ();
519415
520416 input = input.dropSgLayoutAndData ();
521417 target = target.dropSgLayoutAndData ();
522418
523- SmallVector<Value> newOps;
524- for (auto src : values) {
525- auto newOp = rewriter.create <xegpu::ConvertLayoutOp>(
526- op.getLoc (), src.getType (), src, input, target);
527- newOps.push_back (newOp);
419+ SmallVector<Value> newOps (adaptor.getSource ());
420+
421+ if (input && target) {
422+ for (auto [i, src] : llvm::enumerate (adaptor.getSource ())) {
423+ auto newOp = rewriter.create <xegpu::ConvertLayoutOp>(
424+ op.getLoc (), src.getType (), src, input, target);
425+ newOps[i] = newOp;
426+ }
528427 }
529- rewriter.replaceOpWithMultiple (op, newOps);
428+ rewriter.replaceOpWithMultiple (op, { newOps} );
530429 return success ();
531430 }
532431};