Skip to content

Commit e1c8963

Browse files
committed
save work
1 parent ff24db0 commit e1c8963

File tree

3 files changed

+12
-52
lines changed

3 files changed

+12
-52
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -608,8 +608,8 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
608608
return getElementTypeOrSelf(type);
609609
}
610610

611-
Type getValueType() {
612-
return getValue().getType();
611+
VectorType getValueType() {
612+
return llvm::dyn_cast<VectorType>(getValue().getType());
613613
}
614614

615615
Type getMaskType() {
@@ -668,8 +668,8 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [
668668
return getTensorDesc().getType();
669669
}
670670

671-
Type getValueType() {
672-
return getValue().getType();
671+
VectorType getValueType() {
672+
return llvm::dyn_cast<VectorType>(getValue().getType());
673673
}
674674

675675
Type getMaskType() {

mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp

Lines changed: 4 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -500,28 +500,8 @@ LogicalResult LoadGatherOp::verify() {
500500
transpose({1, 0}, tdescShape);
501501
}
502502

503-
auto sgMap = tdescTy.getSGMapAttr();
504-
// In SIMD mode, sg_map is not present. In this case, value shape must match
505-
// the tensor descriptor shape.
506-
if (!sgMap)
507-
return valueShape == tdescShape
508-
? success()
509-
: emitOpError("Unexpected result shape")
510-
<< "(Expected shape: " << makeString(tdescShape)
511-
<< ", Given shape: " << makeString(valueShape) << ").\n";
512-
// In SIMT mode, sg_map, wi_data, and chunk_size determine the value shape.
513-
auto distributedVectorShapeOrFailure = tdescTy.getDistributedVectorType();
514-
if (failed(distributedVectorShapeOrFailure))
515-
return emitOpError("Failed to compute distributed vector shape for "
516-
"tensor descriptor ")
517-
<< tdescTy;
518-
if (cast<VectorType>(valueTy) != distributedVectorShapeOrFailure.value())
519-
return emitOpError("Unexpected result shape")
520-
<< "(Expected shape: "
521-
<< makeString(distributedVectorShapeOrFailure.value().getShape())
522-
<< ", Given shape: " << makeString(valueShape) << ").\n";
523-
524-
return success();
503+
return isArgShapesValid(tdescTy, valueTy, tdescShape,
504+
[&]() { return emitOpError(); });
525505
}
526506

527507
//===----------------------------------------------------------------------===//
@@ -555,28 +535,8 @@ LogicalResult StoreScatterOp::verify() {
555535
transpose({1, 0}, tdescShape);
556536
}
557537

558-
auto sgMap = tdescTy.getSGMapAttr();
559-
// In SIMD mode, sg_map is not present. In this case, value shape must match
560-
// the tensor descriptor shape.
561-
if (!sgMap)
562-
return valueShape == tdescShape
563-
? success()
564-
: emitOpError("Unexpected value shape")
565-
<< "(Expected shape: " << makeString(tdescShape)
566-
<< ", Given shape: " << makeString(valueShape) << ").\n";
567-
// In SIMT mode, sg_map, wi_data, and chunk_size determine the value shape.
568-
auto distributedVectorShapeOrFailure = tdescTy.getDistributedVectorType();
569-
if (failed(distributedVectorShapeOrFailure))
570-
return emitOpError("Failed to compute distributed vector shape for "
571-
"tensor descriptor ")
572-
<< tdescTy;
573-
if (cast<VectorType>(valueTy) != distributedVectorShapeOrFailure.value())
574-
return emitOpError("Unexpected value shape")
575-
<< "(Expected shape: "
576-
<< makeString(distributedVectorShapeOrFailure.value().getShape())
577-
<< ", Given shape: " << makeString(valueShape) << ").\n";
578-
579-
return success();
538+
return isArgShapesValid(tdescTy, valueTy, tdescShape,
539+
[&]() { return emitOpError(); });
580540
}
581541

582542
//===----------------------------------------------------------------------===//

mlir/test/Dialect/XeGPU/invalid.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ func.func @test_load_gather_sg_map_1(%src: ui64) {
273273
%0 = arith.constant dense<1>: vector<4xi1>
274274
%cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
275275
%1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
276-
// expected-error@+1 {{Unexpected result shape(Expected shape: [2, 1], Given shape: [1, 2])}}
276+
// expected-error@+1 {{Result shape [1, 2] is not consistent with distributed vector shape [2, 1] for tensor descriptor}}
277277
%2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<cached>, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>, vector<4xi1> -> vector<1x2xf32>
278278
return
279279
}
@@ -283,7 +283,7 @@ func.func @test_load_gather_sg_map_2(%src: ui64) {
283283
%0 = arith.constant dense<1>: vector<4xi1>
284284
%cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
285285
%1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
286-
// expected-error@+1 {{Unexpected result shape(Expected shape: [2, 1], Given shape: [2])}}
286+
// expected-error@+1 {{esult shape [2] is not consistent with distributed vector shape [2, 1] for tensor descriptor}}
287287
%2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<cached>, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>, vector<4xi1> -> vector<2xf32>
288288
return
289289
}
@@ -295,7 +295,7 @@ func.func @test_store_scatter_sg_map_1(%src: ui64) {
295295
%cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
296296
%val = arith.constant dense<2.9>: vector<1x2xf32>
297297
%1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
298-
// expected-error@+1 {{Unexpected value shape(Expected shape: [2, 1], Given shape: [1, 2])}}
298+
// expected-error@+1 {{Result shape [1, 2] is not consistent with distributed vector shape [2, 1] for tensor descriptor}}
299299
xegpu.store %val, %1, %0 <{l1_hint = #xegpu.cache_hint<cached>, transpose}> : vector<1x2xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>, vector<4xi1>
300300
return
301301
}
@@ -306,7 +306,7 @@ func.func @test_store_scatter_sg_map_2(%src: ui64) {
306306
%cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
307307
%val = arith.constant dense<2.9>: vector<2xf32>
308308
%1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
309-
// expected-error@+1 {{Unexpected value shape(Expected shape: [2, 1], Given shape: [2])}}
309+
// expected-error@+1 {{esult shape [2] is not consistent with distributed vector shape [2, 1] for tensor descriptor}}
310310
xegpu.store %val, %1, %0 <{l1_hint = #xegpu.cache_hint<cached>, transpose}> : vector<2xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>, vector<4xi1>
311311
return
312312
}

0 commit comments

Comments
 (0)