|
12 | 12 | #include "mlir/Dialect/Arith/Utils/Utils.h" |
13 | 13 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
14 | 14 | #include "mlir/Dialect/GPU/IR/GPUDialect.h" |
| 15 | +#include "mlir/Dialect/GPU/Utils/DistributionUtils.h" |
15 | 16 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
16 | 17 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
17 | 18 | #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h" |
@@ -66,12 +67,12 @@ bool divisible(APInt lhs, APInt rhs) { return !lhs.urem(rhs); } |
66 | 67 | // /// -> !xegpu.tensor_desc<4x1xf32> |
67 | 68 | // /// |
68 | 69 | // /// ``` |
69 | | -// struct WarpOpTensorDescOp final |
70 | | -// : public OpRewritePattern<vector::WarpExecuteOnLane0Op> { |
71 | | -// using OpRewritePattern<vector::WarpExecuteOnLane0Op>::OpRewritePattern; |
72 | | -// LogicalResult matchAndRewrite(vector::WarpExecuteOnLane0Op warpOp, |
73 | | -// PatternRewriter &rewriter) const override; |
74 | | -// }; |
/// Pattern that distributes a `xegpu.create_nd_tdesc` feeding a result of the
/// enclosing `gpu.warp_execute_on_lane_0` op: the descriptor is recreated
/// outside the warp op with a per-lane (distributed) type over a
/// lane-specific subview of the source memref (see matchAndRewrite below).
struct WarpOpTensorDescOp final
    : public OpRewritePattern<gpu::WarpExecuteOnLane0Op> {
  using OpRewritePattern<gpu::WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override;
};
75 | 76 |
|
76 | 77 | // /// Sink a store_nd feeding into vector.yield op for the enclosing |
77 | 78 | // /// `vector.warp_execute_on_lane_0`. In case arguments for the store are passed |
@@ -159,32 +160,32 @@ bool divisible(APInt lhs, APInt rhs) { return !lhs.urem(rhs); } |
159 | 160 | // return newVectorType; |
160 | 161 | // } |
161 | 162 |
|
162 | | -// FailureOr<xegpu::TensorDescType> |
163 | | -// getDistributedTensorDescType(xegpu::TensorDescType originalT, |
164 | | -// xegpu::SGMapAttr sgMap, |
165 | | -// xegpu::MemorySpace memSpace) { |
166 | | -// llvm::SmallVector<int64_t, 2> distributedShape; |
167 | | -// auto layout = sgMap.getWiLayout(); |
168 | | -// auto shape = originalT.getShape(); |
169 | | -// for (const auto [l, o] : llvm::zip_equal(layout, shape)) { |
170 | | -// if (!divisible(APInt(64, o), APInt(64, l))) |
171 | | -// return failure(); |
172 | | -// distributedShape.push_back(o / l); |
173 | | -// } |
174 | | -// xegpu::TensorDescType distributedDescType; |
175 | | -// if (originalT.isScattered()) { |
176 | | - |
177 | | -// distributedDescType = xegpu::TensorDescType::get( |
178 | | -// distributedShape, originalT.getElementType(), originalT.getChunkSize(), |
179 | | -// originalT.getMemorySpace(), originalT.getSGMapAttr()); |
180 | | -// } else { |
181 | | -// distributedDescType = xegpu::TensorDescType::get( |
182 | | -// distributedShape, originalT.getElementType(), |
183 | | -// originalT.getBoundaryCheck(), originalT.getArrayLength(), |
184 | | -// originalT.getMemorySpace(), originalT.getSGMapAttr()); |
185 | | -// } |
186 | | -// return distributedDescType; |
187 | | -// } |
| 163 | +FailureOr<xegpu::TensorDescType> |
| 164 | +getDistributedTensorDescType(xegpu::TensorDescType originalT, |
| 165 | + xegpu::SGMapAttr sgMap, |
| 166 | + xegpu::MemorySpace memSpace) { |
| 167 | + llvm::SmallVector<int64_t, 2> distributedShape; |
| 168 | + auto layout = sgMap.getWiLayout(); |
| 169 | + auto shape = originalT.getShape(); |
| 170 | + for (const auto [l, o] : llvm::zip_equal(layout, shape)) { |
| 171 | + if (!divisible(APInt(64, o), APInt(64, l))) |
| 172 | + return failure(); |
| 173 | + distributedShape.push_back(o / l); |
| 174 | + } |
| 175 | + xegpu::TensorDescType distributedDescType; |
| 176 | + if (originalT.isScattered()) { |
| 177 | + |
| 178 | + distributedDescType = xegpu::TensorDescType::get( |
| 179 | + distributedShape, originalT.getElementType(), originalT.getChunkSize(), |
| 180 | + originalT.getMemorySpace(), originalT.getSGMapAttr()); |
| 181 | + } else { |
| 182 | + distributedDescType = xegpu::TensorDescType::get( |
| 183 | + distributedShape, originalT.getElementType(), |
| 184 | + originalT.getBoundaryCheck(), originalT.getArrayLength(), |
| 185 | + originalT.getMemorySpace(), originalT.getSGMapAttr()); |
| 186 | + } |
| 187 | + return distributedDescType; |
| 188 | +} |
188 | 189 | } // namespace |
189 | 190 |
|
190 | 191 | // LogicalResult |
@@ -303,91 +304,91 @@ bool divisible(APInt lhs, APInt rhs) { return !lhs.urem(rhs); } |
303 | 304 | // return success(); |
304 | 305 | // } |
305 | 306 |
|
306 | | -// LogicalResult |
307 | | -// WarpOpTensorDescOp::matchAndRewrite(vector::WarpExecuteOnLane0Op warpOp, |
308 | | -// PatternRewriter &rewriter) const { |
309 | | -// OpOperand *operand = getWarpResult(warpOp, [](Operation *op) { |
310 | | -// return isa<xegpu::CreateNdDescOp>(op) && op->hasOneUse(); |
311 | | -// }); |
312 | | - |
313 | | -// if (!operand) |
314 | | -// return rewriter.notifyMatchFailure( |
315 | | -// warpOp, "warp result is not a xegpu::CreateNdDesc op"); |
316 | | -// auto descOp = operand->get().getDefiningOp<xegpu::CreateNdDescOp>(); |
317 | | -// assert(descOp && "desc op must be not null"); |
318 | | -// unsigned operandIdx = operand->getOperandNumber(); |
319 | | - |
320 | | -// // TODO: is memref uniform in the region |
321 | | -// rewriter.setInsertionPoint(warpOp); |
322 | | -// auto srcTypedVal = dyn_cast<TypedValue<MemRefType>>(descOp.getSource()); |
323 | | -// assert(srcTypedVal && "source value must be not null"); |
324 | | - |
325 | | -// auto descOffsets = descOp.getMixedOffsets(); |
326 | | -// if (descOffsets.size() != 2) |
327 | | -// return rewriter.notifyMatchFailure(descOp, |
328 | | -// "offsets size is expected to be 2"); |
329 | | - |
330 | | -// xegpu::SGMapAttr sgMap = descOp.getType().getSGMapAttr(); |
331 | | -// if (!sgMap) |
332 | | -// return rewriter.notifyMatchFailure( |
333 | | -// descOp, "the tensor descriptor lacks sg_map attribute"); |
334 | | - |
335 | | -// auto layout = sgMap.getWiLayout(); |
336 | | - |
337 | | -// // Calculate the offset within tensor descriptor for the current lane_id. The |
338 | | -// // access to proper element for a work item is done through a lane-specific |
339 | | -// // subview (tdesc offsets are used as base, lane shift is added on top). |
340 | | -// auto laneid = warpOp.getLaneid(); |
341 | | -// auto xDim = |
342 | | -// rewriter.create<arith::ConstantIndexOp>(laneid.getLoc(), layout[0]); |
343 | | -// auto shiftx = rewriter.create<arith::RemUIOp>(laneid.getLoc(), laneid, xDim); |
344 | | -// auto shifty = rewriter.create<arith::DivUIOp>(laneid.getLoc(), laneid, xDim); |
345 | | - |
346 | | -// auto basex = getValueOrCreateConstantIndexOp(rewriter, laneid.getLoc(), |
347 | | -// descOffsets[0]); |
348 | | -// auto basey = getValueOrCreateConstantIndexOp(rewriter, laneid.getLoc(), |
349 | | -// descOffsets[1]); |
350 | | -// auto offsetx = rewriter.create<arith::AddIOp>(laneid.getLoc(), shiftx, basex); |
351 | | -// auto offsety = rewriter.create<arith::AddIOp>(laneid.getLoc(), shifty, basey); |
352 | | - |
353 | | -// auto distributedDescTypeOrFailure = getDistributedTensorDescType( |
354 | | -// descOp.getType(), sgMap, descOp.getType().getMemorySpace()); |
355 | | -// if (failed(distributedDescTypeOrFailure)) |
356 | | -// return rewriter.notifyMatchFailure(descOp, |
357 | | -// "Failed to distribute the desc type"); |
358 | | -// xegpu::TensorDescType newTDescType = distributedDescTypeOrFailure.value(); |
359 | | -// auto distributedShape = newTDescType.getShape(); |
360 | | -// // use the base memref strides |
361 | | -// SmallVector<OpFoldResult> overwriteStrides = |
362 | | -// getAsIndexOpFoldResult(rewriter.getContext(), SmallVector<int64_t>{1, 1}); |
363 | | -// SmallVector<OpFoldResult> overwriteSizes = |
364 | | -// getAsIndexOpFoldResult(rewriter.getContext(), distributedShape); |
365 | | - |
366 | | -// SmallVector<size_t> newRetIndices; |
367 | | -// vector::WarpExecuteOnLane0Op newWarpOp = |
368 | | -// moveRegionToNewWarpOpAndAppendReturns( |
369 | | -// rewriter, warpOp, descOp.getSource(), descOp.getSourceType(), |
370 | | -// newRetIndices); |
371 | | - |
372 | | -// rewriter.setInsertionPointAfter(newWarpOp); |
373 | | -// auto subview = rewriter.create<memref::SubViewOp>( |
374 | | -// newWarpOp.getLoc(), srcTypedVal, getAsOpFoldResult({offsetx, offsety}), |
375 | | -// overwriteSizes, overwriteStrides); |
376 | | -// subview.getSourceMutable().assign(newWarpOp.getResult(newRetIndices[0])); |
377 | | - |
378 | | -// auto zero = rewriter.create<arith::ConstantIndexOp>(laneid.getLoc(), 0); |
379 | | -// auto newDescOp = rewriter.create<xegpu::CreateNdDescOp>( |
380 | | -// newWarpOp.getLoc(), newTDescType, subview, |
381 | | -// getAsOpFoldResult({zero, zero})); |
382 | | - |
383 | | -// Value distributedVal = newWarpOp.getResult(operandIdx); |
384 | | -// rewriter.replaceAllUsesWith(distributedVal, newDescOp); |
385 | | - |
386 | | -// return success(); |
387 | | -// } |
/// Sink a single-use `xegpu.create_nd_tdesc` out of the warp op: yield its
/// source memref from the warp region, build a lane-specific subview of it
/// after the warp op, and recreate the descriptor with the distributed type
/// over that subview, replacing all uses of the original warp result.
LogicalResult
WarpOpTensorDescOp::matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                    PatternRewriter &rewriter) const {
  // Look for a warp result defined by a CreateNdDescOp with exactly one use.
  OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
    return isa<xegpu::CreateNdDescOp>(op) && op->hasOneUse();
  });

  if (!operand)
    return rewriter.notifyMatchFailure(
        warpOp, "warp result is not a xegpu::CreateNdDesc op");
  auto descOp = operand->get().getDefiningOp<xegpu::CreateNdDescOp>();
  assert(descOp && "desc op must be not null");
  unsigned operandIdx = operand->getOperandNumber();

  // TODO: is memref uniform in the region
  rewriter.setInsertionPoint(warpOp);
  auto srcTypedVal = dyn_cast<TypedValue<MemRefType>>(descOp.getSource());
  assert(srcTypedVal && "source value must be not null");

  // The lane-offset arithmetic below is written for rank-2 descriptors only.
  auto descOffsets = descOp.getMixedOffsets();
  if (descOffsets.size() != 2)
    return rewriter.notifyMatchFailure(descOp,
                                       "offsets size is expected to be 2");

  // The sg_map attribute supplies the workitem layout used for distribution.
  xegpu::SGMapAttr sgMap = descOp.getType().getSGMapAttr();
  if (!sgMap)
    return rewriter.notifyMatchFailure(
        descOp, "the tensor descriptor lacks sg_map attribute");

  auto layout = sgMap.getWiLayout();

  // Calculate the offset within tensor descriptor for the current lane_id. The
  // access to proper element for a work item is done through a lane-specific
  // subview (tdesc offsets are used as base, lane shift is added on top).
  // Lanes are laid out row-major over layout[0]: x = laneid % layout[0],
  // y = laneid / layout[0].
  auto laneid = warpOp.getLaneid();
  auto xDim =
      rewriter.create<arith::ConstantIndexOp>(laneid.getLoc(), layout[0]);
  auto shiftx = rewriter.create<arith::RemUIOp>(laneid.getLoc(), laneid, xDim);
  auto shifty = rewriter.create<arith::DivUIOp>(laneid.getLoc(), laneid, xDim);

  auto basex = getValueOrCreateConstantIndexOp(rewriter, laneid.getLoc(),
                                               descOffsets[0]);
  auto basey = getValueOrCreateConstantIndexOp(rewriter, laneid.getLoc(),
                                               descOffsets[1]);
  auto offsetx = rewriter.create<arith::AddIOp>(laneid.getLoc(), shiftx, basex);
  auto offsety = rewriter.create<arith::AddIOp>(laneid.getLoc(), shifty, basey);

  auto distributedDescTypeOrFailure = getDistributedTensorDescType(
      descOp.getType(), sgMap, descOp.getType().getMemorySpace());
  if (failed(distributedDescTypeOrFailure))
    return rewriter.notifyMatchFailure(descOp,
                                       "Failed to distribute the desc type");
  xegpu::TensorDescType newTDescType = distributedDescTypeOrFailure.value();
  auto distributedShape = newTDescType.getShape();
  // use the base memref strides
  SmallVector<OpFoldResult> overwriteStrides =
      getAsIndexOpFoldResult(rewriter.getContext(), SmallVector<int64_t>{1, 1});
  SmallVector<OpFoldResult> overwriteSizes =
      getAsIndexOpFoldResult(rewriter.getContext(), distributedShape);

  // Yield the descriptor's source memref from the warp region so it is
  // available after the (rebuilt) warp op.
  SmallVector<size_t> newRetIndices;
  gpu::WarpExecuteOnLane0Op newWarpOp =
      moveRegionToNewWarpOpAndAppendReturns(
          rewriter, warpOp, descOp.getSource(), descOp.getSourceType(),
          newRetIndices);

  rewriter.setInsertionPointAfter(newWarpOp);
  auto subview = rewriter.create<memref::SubViewOp>(
      newWarpOp.getLoc(), srcTypedVal, getAsOpFoldResult({offsetx, offsety}),
      overwriteSizes, overwriteStrides);
  // Re-point the subview at the memref yielded by the new warp op.
  subview.getSourceMutable().assign(newWarpOp.getResult(newRetIndices[0]));

  // The distributed descriptor starts at (0, 0) of the lane-specific subview;
  // the lane shift is already folded into the subview offsets.
  auto zero = rewriter.create<arith::ConstantIndexOp>(laneid.getLoc(), 0);
  auto newDescOp = rewriter.create<xegpu::CreateNdDescOp>(
      newWarpOp.getLoc(), newTDescType, subview,
      getAsOpFoldResult({zero, zero}));

  // Replace uses of the original warp result with the new descriptor.
  Value distributedVal = newWarpOp.getResult(operandIdx);
  rewriter.replaceAllUsesWith(distributedVal, newDescOp);

  return success();
}
388 | 389 |
|
/// Populate `patterns` with XeGPU SIMT distribution patterns. Only the
/// tensor-descriptor pattern is enabled here; the store/load patterns below
/// are currently disabled (commented out) — presumably still being ported to
/// the gpu.warp_execute_on_lane_0 infrastructure; confirm before enabling.
void xegpu::populateXeGPUDistributePatterns(RewritePatternSet &patterns) {
  patterns.add<WarpOpTensorDescOp>(patterns.getContext());
  // patterns.add<WarpOpStoreNd>(patterns.getContext());
  // patterns.add<WarpOpLoadNd>(patterns.getContext());
}
0 commit comments