|
11 | 11 | #include "mlir/Dialect/XeGPU/IR/XeGPU.h" |
12 | 12 | #include "mlir/IR/Builders.h" |
13 | 13 | #include "mlir/IR/BuiltinTypes.h" |
| 14 | +#include "mlir/IR/Diagnostics.h" |
14 | 15 | #include "mlir/IR/TypeUtilities.h" |
15 | 16 | #include "mlir/Support/LLVM.h" |
16 | 17 |
|
@@ -76,6 +77,39 @@ static bool isWriteHintOrNone(const CachePolicyAttr &attr) { |
76 | 77 | kind == CachePolicy::WRITE_BACK || kind == CachePolicy::WRITE_THROUGH; |
77 | 78 | } |
78 | 79 |
|
| 80 | +// Helper to validate value shape of LoadNd and StoreNd ops. |
| 81 | +static LogicalResult |
| 82 | +isArgShapesValid(TensorDescType tdescTy, VectorType valueTy, |
| 83 | + ArrayRef<int64_t> adjustedTdescShape, |
| 84 | + function_ref<InFlightDiagnostic()> emitError) { |
| 85 | + auto sgMap = tdescTy.getSGMapAttr(); |
| 86 | + auto valueShape = valueTy.getShape(); |
| 87 | + // sg_map not present means IR is in SIMD mode. In this case value shape must |
| 88 | + // match adjusted tensor descriptor shape. |
| 89 | + if (!sgMap) |
| 90 | + return valueShape == adjustedTdescShape |
| 91 | + ? success() |
| 92 | + : emitError() |
| 93 | + << "Value shape " << makeString(valueShape) |
| 94 | + << " is not consistent with tensor descriptor " << tdescTy; |
| 95 | + |
| 96 | + // sg_map present means IR is in SIMT mode. In this case sg_map determines the |
| 97 | + // value shape. |
| 98 | + auto expectedValueShapeOrFailure = tdescTy.getDistributedVectorType(); |
| 99 | + if (failed(expectedValueShapeOrFailure)) |
| 100 | + return emitError() << "Failed to compute distributed vector shape for " |
| 101 | + "tensor descriptor " |
| 102 | + << tdescTy; |
| 103 | + |
| 104 | + return valueTy == expectedValueShapeOrFailure.value() |
| 105 | + ? success() |
| 106 | + : emitError() |
| 107 | + << "Result shape " << makeString(valueShape) |
| 108 | + << " is not consistent with distributed vector shape " |
| 109 | + << makeString(expectedValueShapeOrFailure.value().getShape()) |
| 110 | + << " for tensor descriptor " << tdescTy; |
| 111 | +} |
| 112 | + |
79 | 113 | //===----------------------------------------------------------------------===// |
80 | 114 | // XeGPU_CreateNdDescOp |
81 | 115 | //===----------------------------------------------------------------------===// |
@@ -282,31 +316,8 @@ LogicalResult LoadNdOp::verify() { |
282 | 316 | adjustedTdescShape.insert(it, array_len); |
283 | 317 | } |
284 | 318 |
|
285 | | - auto sgMap = tdescTy.getSGMapAttr(); |
286 | | - // sg_map not present means IR is in SIMD mode. In this case value shape must |
287 | | - // match adjusted tensor descriptor shape. |
288 | | - if (!sgMap) |
289 | | - return valueShape == adjustedTdescShape |
290 | | - ? success() |
291 | | - : emitOpError() |
292 | | - << "Result shape " << makeString(valueShape) |
293 | | - << " is not consistent with tensor descriptor " << tdescTy; |
294 | | - |
295 | | - // sg_map present means IR is in SIMT mode. In this case sg_map determines the |
296 | | - // value shape. |
297 | | - auto expectedValueShapeOrFailure = tdescTy.getDistributedVectorType(); |
298 | | - if (failed(expectedValueShapeOrFailure)) |
299 | | - return emitOpError() << "Failed to compute distributed vector shape for " |
300 | | - "tensor descriptor " |
301 | | - << tdescTy; |
302 | | - |
303 | | - return valueTy == expectedValueShapeOrFailure.value() |
304 | | - ? success() |
305 | | - : emitOpError() |
306 | | - << "Result shape " << makeString(valueShape) |
307 | | - << " is not consistent with distributed vector shape " |
308 | | - << makeString(expectedValueShapeOrFailure.value().getShape()) |
309 | | - << " for tensor descriptor " << tdescTy; |
| 319 | + return isArgShapesValid(tdescTy, valueTy, adjustedTdescShape, |
| 320 | + [&]() { return emitOpError(); }); |
310 | 321 | } |
311 | 322 |
|
312 | 323 | //===----------------------------------------------------------------------===// |
@@ -337,32 +348,8 @@ LogicalResult StoreNdOp::verify() { |
337 | 348 | auto tdescShape = getShapeOf(dstTy); |
338 | 349 | auto valueShape = getShapeOf(valTy); |
339 | 350 |
|
340 | | - auto sgMap = dstTy.getSGMapAttr(); |
341 | | - // sg_map not present means IR is in SIMD mode. In this case value shape must |
342 | | - // match adjusted tensor descriptor shape. |
343 | | - if (!sgMap) |
344 | | - return valueShape == tdescShape |
345 | | - ? success() |
346 | | - : emitOpError() |
347 | | - << "Result shape " << makeString(valueShape) |
348 | | - << " is not consistent with tensor descriptor shape " |
349 | | - << makeString(tdescShape); |
350 | | - |
351 | | - // sg_map present means IR is in SIMT mode. In this case sg_map determines the |
352 | | - // value shape. |
353 | | - auto expectedValueShapeOrFailure = dstTy.getDistributedVectorType(); |
354 | | - if (failed(expectedValueShapeOrFailure)) |
355 | | - return emitOpError() << "Failed to compute distributed vector shape for " |
356 | | - "tensor descriptor " |
357 | | - << dstTy; |
358 | | - |
359 | | - return valTy == expectedValueShapeOrFailure.value() |
360 | | - ? success() |
361 | | - : emitOpError() |
362 | | - << "Result shape " << makeString(valueShape) |
363 | | - << " is not consistent with distributed vector shape " |
364 | | - << makeString(expectedValueShapeOrFailure.value().getShape()) |
365 | | - << " for tensor descriptor " << dstTy; |
| 351 | + return isArgShapesValid(dstTy, valTy, tdescShape, |
| 352 | + [&]() { return emitOpError(); }); |
366 | 353 | } |
367 | 354 |
|
368 | 355 | //===----------------------------------------------------------------------===// |
|
0 commit comments