Skip to content

Commit ab59c46

Browse files
committed
save work
1 parent 4b5cffb commit ab59c46

File tree

3 files changed

+51
-64
lines changed

3 files changed

+51
-64
lines changed

mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "llvm/ADT/TypeSwitch.h"
1515
#include "llvm/Support/Casting.h"
1616
#include "llvm/Support/LogicalResult.h"
17+
#include <cassert>
1718

1819
namespace mlir {
1920
namespace xegpu {
@@ -281,11 +282,11 @@ LogicalResult TensorDescType::verify(
281282
// Validate subgroup mapping rules for scattered tensors.
282283
// A work-item's slice of the tensor with shape [sg_size] or
283284
// [sg_size, chunk_size] will be [1] or [1, 32/element_ty_bit_width]
284-
// respectively, the mapping should reflect that.
285+
// respectively, the mapping should reflect that. This is because each
286+
// work item accesses data in 32-bit granularity.
285287
if (wiData[0] != 1)
286288
return emitError()
287289
<< "cannot map over non-contiguous scattered row elements";
288-
289290
if (wiData[1] != (32 / elementType.getIntOrFloatBitWidth()))
290291
return emitError() << "work item data mapping must match the number of "
291292
"contiguous elements";
@@ -351,14 +352,13 @@ FailureOr<VectorType> TensorDescType::getDistributedVectorType() {
351352
}
352353

353354
// Case 1: regular loads/stores
354-
auto scatterAttr =
355-
llvm::dyn_cast_if_present<ScatterTensorDescAttr>(getEncoding());
355+
auto scatterAttr = getEncodingAsScatterTensorDescAttr();
356356
if (scatterAttr) {
357357
auto chunkSize = scatterAttr.getChunkSize().getInt();
358-
// Check if the first dimension of the tensor descriptor shape is
358+
// Verify that the first dimension of the tensor descriptor shape is
359359
// distributable.
360-
if (tdescShape[0] % (wiLayout[0]) != 0)
361-
return failure();
360+
assert(tdescShape[0] % (wiLayout[0]) == 0 &&
361+
"tensor descriptor shape is not distributable");
362362
if (chunkSize > 1)
363363
return VectorType::get({chunkSize / wiDataSize, wiDataSize},
364364
getElementType());
@@ -369,17 +369,17 @@ FailureOr<VectorType> TensorDescType::getDistributedVectorType() {
369369
// Tensor descriptor shape can be 1D. For the 1D case, outer dims of wiData
370370
// and wiLayout must be 1.
371371
if (tdescShape.size() == 1) {
372-
if (wiData[0] != 1 || wiLayout[0] != 1)
373-
return failure();
372+
assert((wiData[0] == 1 && wiLayout[0] == 1) &&
373+
"wi_data[0] and wi_layout[0] must be 1 for 1D tensor descriptor");
374374
wiData = {wiData[1]};
375375
wiLayout = {wiLayout[1]};
376376
}
377377
// Check if the tensor descriptor shape is distributable.
378378
int64_t tensorSize = 1;
379379
for (auto [tdescDim, wiDim, wiDataDim] :
380380
llvm::zip_equal(tdescShape, wiLayout, wiData)) {
381-
if (tdescDim % (wiDim * wiDataDim) != 0)
382-
return failure();
381+
assert((tdescDim % (wiDim * wiDataDim) == 0) &&
382+
"tensor descriptor shape is not distributable");
383383
tensorSize *= tdescDim;
384384
}
385385
// tensorSize must be adjusted for array_length.

mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp

Lines changed: 38 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
1212
#include "mlir/IR/Builders.h"
1313
#include "mlir/IR/BuiltinTypes.h"
14+
#include "mlir/IR/Diagnostics.h"
1415
#include "mlir/IR/TypeUtilities.h"
1516
#include "mlir/Support/LLVM.h"
1617

@@ -76,6 +77,39 @@ static bool isWriteHintOrNone(const CachePolicyAttr &attr) {
7677
kind == CachePolicy::WRITE_BACK || kind == CachePolicy::WRITE_THROUGH;
7778
}
7879

80+
// Helper to validate value shape of LoadNd and StoreNd ops.
81+
static LogicalResult
82+
isArgShapesValid(TensorDescType tdescTy, VectorType valueTy,
83+
ArrayRef<int64_t> adjustedTdescShape,
84+
function_ref<InFlightDiagnostic()> emitError) {
85+
auto sgMap = tdescTy.getSGMapAttr();
86+
auto valueShape = valueTy.getShape();
87+
// sg_map not present means IR is in SIMD mode. In this case value shape must
88+
// match adjusted tensor descriptor shape.
89+
if (!sgMap)
90+
return valueShape == adjustedTdescShape
91+
? success()
92+
: emitError()
93+
<< "Value shape " << makeString(valueShape)
94+
<< " is not consistent with tensor descriptor " << tdescTy;
95+
96+
// sg_map present means IR is in SIMT mode. In this case sg_map determines the
97+
// value shape.
98+
auto expectedValueShapeOrFailure = tdescTy.getDistributedVectorType();
99+
if (failed(expectedValueShapeOrFailure))
100+
return emitError() << "Failed to compute distributed vector shape for "
101+
"tensor descriptor "
102+
<< tdescTy;
103+
104+
return valueTy == expectedValueShapeOrFailure.value()
105+
? success()
106+
: emitError()
107+
<< "Result shape " << makeString(valueShape)
108+
<< " is not consistent with distributed vector shape "
109+
<< makeString(expectedValueShapeOrFailure.value().getShape())
110+
<< " for tensor descriptor " << tdescTy;
111+
}
112+
79113
//===----------------------------------------------------------------------===//
80114
// XeGPU_CreateNdDescOp
81115
//===----------------------------------------------------------------------===//
@@ -282,31 +316,8 @@ LogicalResult LoadNdOp::verify() {
282316
adjustedTdescShape.insert(it, array_len);
283317
}
284318

285-
auto sgMap = tdescTy.getSGMapAttr();
286-
// sg_map not present means IR is in SIMD mode. In this case value shape must
287-
// match adjusted tensor descriptor shape.
288-
if (!sgMap)
289-
return valueShape == adjustedTdescShape
290-
? success()
291-
: emitOpError()
292-
<< "Result shape " << makeString(valueShape)
293-
<< " is not consistent with tensor descriptor " << tdescTy;
294-
295-
// sg_map present means IR is in SIMT mode. In this case sg_map determines the
296-
// value shape.
297-
auto expectedValueShapeOrFailure = tdescTy.getDistributedVectorType();
298-
if (failed(expectedValueShapeOrFailure))
299-
return emitOpError() << "Failed to compute distributed vector shape for "
300-
"tensor descriptor "
301-
<< tdescTy;
302-
303-
return valueTy == expectedValueShapeOrFailure.value()
304-
? success()
305-
: emitOpError()
306-
<< "Result shape " << makeString(valueShape)
307-
<< " is not consistent with distributed vector shape "
308-
<< makeString(expectedValueShapeOrFailure.value().getShape())
309-
<< " for tensor descriptor " << tdescTy;
319+
return isArgShapesValid(tdescTy, valueTy, adjustedTdescShape,
320+
[&]() { return emitOpError(); });
310321
}
311322

312323
//===----------------------------------------------------------------------===//
@@ -337,32 +348,8 @@ LogicalResult StoreNdOp::verify() {
337348
auto tdescShape = getShapeOf(dstTy);
338349
auto valueShape = getShapeOf(valTy);
339350

340-
auto sgMap = dstTy.getSGMapAttr();
341-
// sg_map not present means IR is in SIMD mode. In this case value shape must
342-
// match adjusted tensor descriptor shape.
343-
if (!sgMap)
344-
return valueShape == tdescShape
345-
? success()
346-
: emitOpError()
347-
<< "Result shape " << makeString(valueShape)
348-
<< " is not consistent with tensor descriptor shape "
349-
<< makeString(tdescShape);
350-
351-
// sg_map present means IR is in SIMT mode. In this case sg_map determines the
352-
// value shape.
353-
auto expectedValueShapeOrFailure = dstTy.getDistributedVectorType();
354-
if (failed(expectedValueShapeOrFailure))
355-
return emitOpError() << "Failed to compute distributed vector shape for "
356-
"tensor descriptor "
357-
<< dstTy;
358-
359-
return valTy == expectedValueShapeOrFailure.value()
360-
? success()
361-
: emitOpError()
362-
<< "Result shape " << makeString(valueShape)
363-
<< " is not consistent with distributed vector shape "
364-
<< makeString(expectedValueShapeOrFailure.value().getShape())
365-
<< " for tensor descriptor " << dstTy;
351+
return isArgShapesValid(dstTy, valTy, tdescShape,
352+
[&]() { return emitOpError(); });
366353
}
367354

368355
//===----------------------------------------------------------------------===//

mlir/test/Dialect/XeGPU/invalid.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ func.func @test_load_nd_sg_map(%src: memref<24x32xf32>) {
105105
func.func @test_load_nd_vc_6(%src: memref<24x32xf32>) {
106106
%1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
107107
!xegpu.tensor_desc<8x16xf32>
108-
// expected-error@+1 {{Result shape [8, 1] is not consistent with tensor descriptor}}
108+
// expected-error@+1 {{Value shape [8, 1] is not consistent with tensor descriptor}}
109109
%2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>,
110110
l2_hint = #xegpu.cache_hint<uncached>}>
111111
: !xegpu.tensor_desc<8x16xf32> -> vector<8x1xf32>
@@ -157,7 +157,7 @@ func.func @test_store_nd_sg_map(%dst: memref<24x32xf32>, %data: vector<2xf32>) {
157157
func.func @test_store_nd_vc_5(%dst: memref<24x32xf32>, %data: vector<8x1xf32>) {
158158
%1 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf32> ->
159159
!xegpu.tensor_desc<8x16xf32>
160-
// expected-error@+1 {{Result shape [8, 1] is not consistent with tensor descriptor shape [8, 16]}}
160+
// expected-error@+1 {{Value shape [8, 1] is not consistent with tensor descriptor}}
161161
xegpu.store_nd %data, %1 : vector<8x1xf32>, !xegpu.tensor_desc<8x16xf32>
162162
return
163163
}

0 commit comments

Comments
 (0)