diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp index d6fc4ed07bfab..0db71f866c024 100644 --- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp @@ -2298,21 +2298,26 @@ mlir::affine::affineDataCopyGenerate(Block::iterator begin, Block::iterator end, // Walk this range of operations to gather all memory regions. block->walk(begin, end, [&](Operation *opInst) { + Value memref; + MemRefType memrefType; // Gather regions to allocate to buffers in faster memory space. if (auto loadOp = dyn_cast(opInst)) { - if ((filterMemRef.has_value() && filterMemRef != loadOp.getMemRef()) || - (loadOp.getMemRefType().getMemorySpaceAsInt() != - copyOptions.slowMemorySpace)) - return; + memref = loadOp.getMemRef(); + memrefType = loadOp.getMemRefType(); } else if (auto storeOp = dyn_cast(opInst)) { - if ((filterMemRef.has_value() && filterMemRef != storeOp.getMemRef()) || - storeOp.getMemRefType().getMemorySpaceAsInt() != - copyOptions.slowMemorySpace) - return; - } else { - // Neither load nor a store op. - return; + memref = storeOp.getMemRef(); + memrefType = storeOp.getMemRefType(); } + // Neither load nor a store op. + if (!memref) + return; + + auto memorySpaceAttr = + dyn_cast_or_null(memrefType.getMemorySpace()); + if ((filterMemRef.has_value() && filterMemRef != memref) || + (memorySpaceAttr && + memrefType.getMemorySpaceAsInt() != copyOptions.slowMemorySpace)) + return; // Compute the MemRefRegion accessed. auto region = std::make_unique(opInst->getLoc()); diff --git a/mlir/test/Dialect/Affine/affine-data-copy.mlir b/mlir/test/Dialect/Affine/affine-data-copy.mlir index fe3b4a206e2b9..330cf92bafba4 100644 --- a/mlir/test/Dialect/Affine/affine-data-copy.mlir +++ b/mlir/test/Dialect/Affine/affine-data-copy.mlir @@ -333,3 +333,23 @@ func.func @index_elt_type(%arg0: memref<1x2x4x8xindex>) { // CHECK-NEXT: affine.for %{{.*}} = 0 to 8 return } + +#map = affine_map<(d0) -> (d0 + 1)> + +// CHECK-LABEL: func @arbitrary_memory_space +func.func @arbitrary_memory_space() { + %alloc = memref.alloc() : memref<256x8xi8, #spirv.storage_class> + affine.for %arg0 = 0 to 32 step 4 { + %0 = affine.apply #map(%arg0) + affine.for %arg1 = 0 to 8 step 2 { + %1 = affine.apply #map(%arg1) + affine.for %arg2 = 0 to 8 step 2 { + // CHECK: memref.alloc() : memref<1x7xi8> + %2 = affine.apply #map(%arg2) + %3 = affine.load %alloc[%0, %1] : memref<256x8xi8, #spirv.storage_class> + affine.store %3, %alloc[%0, %2] : memref<256x8xi8, #spirv.storage_class> + } + } + } + return +}