@@ -59,17 +59,38 @@ static SmallVector<Attribute> getThreadMapping(MLIRContext *ctx) {
   return mapping;
 }
 
+/// Check if the source of an operation comes directly from global memory.
+/// Returns false if the source goes through tensor.pad or other local
+/// computation that would prevent using global load DMA.
+static bool sourceIsFromGlobalMemory(Operation *op) {
+  Value source;
+  if (auto copyOp = dyn_cast<linalg::CopyOp>(op)) {
+    source = copyOp.getInputs()[0];
+  } else if (auto gatherOp = dyn_cast<IREE::LinalgExt::GatherOp>(op)) {
+    source = gatherOp.getSource();
+  } else {
+    return false;
+  }
+
+  // Trace through extract_slice operations to find the origin.
+  while (auto extractOp = source.getDefiningOp<tensor::ExtractSliceOp>()) {
+    source = extractOp.getSource();
+  }
+
+  // If the source comes from tensor.pad, it's not directly from global memory.
+  if (source.getDefiningOp<tensor::PadOp>()) {
+    return false;
+  }
+
+  // Otherwise, assume it's from global memory (e.g., dispatch tensor load).
+  return true;
+}
+
 /// Helper to compute thread number of threads based on translation_info.
 /// Uses the subgroup_size from translation_info for thread-level tiling.
 static SmallVector<OpFoldResult>
 computeThreadNumThreadsImpl(OpBuilder &builder, Operation *op,
                              RankedTensorType outputType) {
-  // Check that this operation has the use_global_load_dma config.
-  auto dmaConfig = getLoweringConfig<IREE::GPU::UseGlobalLoadDMAAttr>(op);
-  if (!dmaConfig) {
-    return {};
-  }
-
   // Get the function containing this operation.
   auto funcOp = op->getParentOfType<FunctionOpInterface>();
   if (!funcOp) {
@@ -341,6 +362,11 @@ struct ConvertToCoalescedDMABase : public OpRewritePattern<OpTy> {
       return failure();
     }
 
+    // Check that source comes from global memory (not tensor.pad).
+    if (!sourceIsFromGlobalMemory(op)) {
+      return failure();
+    }
+
     SmallVector<OpFoldResult> threadNumThreads =
         computeThreadNumThreads(rewriter, op);
     if (threadNumThreads.empty()) {
@@ -386,11 +412,8 @@ struct ConvertGatherToCoalescedDMA
       return failure();
     }
 
-    // For gather ops, tile only the innermost dimension to distribute across
-    // threads.
-    auto dmaConfig =
-        getLoweringConfig<IREE::GPU::UseGlobalLoadDMAAttr>(gatherOp);
-    if (!dmaConfig) {
+    // Check that source comes from global memory (not tensor.pad).
+    if (!sourceIsFromGlobalMemory(gatherOp)) {
       return failure();
     }
 
@@ -400,6 +423,12 @@ struct ConvertGatherToCoalescedDMA
       return failure();
     }
 
+    // Check that the target supports global load DMA.
+    IREE::GPU::TargetAttr target = getGPUTargetAttr(funcOp);
+    if (!target || !IREE::GPU::targetSupportsGlobalLoadDMA(target)) {
+      return failure();
+    }
+
     // Get subgroup size from translation_info.
     std::optional<int64_t> subgroupSize = getSubgroupSize(funcOp);
     if (!subgroupSize) {
@@ -418,11 +447,6 @@ struct ConvertGatherToCoalescedDMA
     Type elementType = outputType.getElementType();
     int64_t elementBits = elementType.getIntOrFloatBitWidth();
 
-    IREE::GPU::TargetAttr target = getGPUTargetAttr(funcOp);
-    if (!target) {
-      return failure();
-    }
-
     ArrayRef<int64_t> dmaSizes;
     if (DenseI64ArrayAttr dmaSizesAttr = target.getWgp().getDmaSizes()) {
       dmaSizes = dmaSizesAttr.asArrayRef();
@@ -617,10 +641,6 @@ struct GPUConvertToCoalescedDMAPass final
   FailureOr<scf::SCFTilingResult> tileAtSubgroupLevel(IRRewriter &rewriter,
                                                       OpTy op) {
     MLIRContext *context = &getContext();
-    auto dmaConfig = getLoweringConfig<IREE::GPU::UseGlobalLoadDMAAttr>(op);
-    if (!dmaConfig) {
-      return failure();
-    }
 
     // Get the function containing this operation.
     auto funcOp = op->template getParentOfType<FunctionOpInterface>();
@@ -718,18 +738,26 @@ struct GPUConvertToCoalescedDMAPass final
 
   LogicalResult applySubgroupTiling(FunctionOpInterface funcOp) {
     MLIRContext *context = &getContext();
+
+    // Check if the target supports global load DMA.
+    IREE::GPU::TargetAttr target = getGPUTargetAttr(funcOp);
+    if (!target || !IREE::GPU::targetSupportsGlobalLoadDMA(target)) {
+      return success();
+    }
+
     SmallVector<Operation *> opsToTile;
 
-    // Collect all ops with iree_gpu.use_global_load_dma lowering config.
+    // Collect copy/gather ops that are eligible for coalesced DMA.
     // Skip ops that are already inside a warp-mapped forall.
     funcOp->walk([&](Operation *op) {
       if (isa<linalg::CopyOp, IREE::LinalgExt::GatherOp>(op)) {
-        auto config = getLoweringConfig<IREE::GPU::UseGlobalLoadDMAAttr>(op);
-        if (config) {
-          auto parentForall = op->getParentOfType<scf::ForallOp>();
-          if (!hasWarpMapping(parentForall)) {
-            opsToTile.push_back(op);
-          }
+        // Check that source comes from global memory (not tensor.pad).
+        if (!sourceIsFromGlobalMemory(op)) {
+          return;
+        }
+        auto parentForall = op->getParentOfType<scf::ForallOp>();
+        if (!hasWarpMapping(parentForall)) {
+          opsToTile.push_back(op);
         }
       }
     });
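Note: a minimal standalone sketch of the defining-op tracing idiom that sourceIsFromGlobalMemory relies on, assuming the standard MLIR tensor dialect headers; the helper name rootIsPad is illustrative only and not part of this patch.

#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Value.h"

// Walk backwards through tensor.extract_slice producers to the root value,
// then report whether that root is produced by tensor.pad.
static bool rootIsPad(mlir::Value v) {
  while (auto slice = v.getDefiningOp<mlir::tensor::ExtractSliceOp>())
    v = slice.getSource();
  return v.getDefiningOp<mlir::tensor::PadOp>() != nullptr;
}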