Skip to content

Commit 62be14d

Browse files
committed
Update PR to use the new uArch
1 parent cecfef0 commit 62be14d

File tree

9 files changed

+359
-146
lines changed

9 files changed

+359
-146
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -379,28 +379,28 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
379379
);
380380

381381
let builders = [
382-
AttrBuilder<(ins "llvm::ArrayRef<int32_t>": $lane_layout,
382+
AttrBuilder<(ins "llvm::ArrayRef<int32_t>": $inst_data,
383+
"llvm::ArrayRef<int32_t>": $lane_layout,
383384
"llvm::ArrayRef<int32_t>": $lane_data),
384385
[{
385386
auto sg_layout = DenseI32ArrayAttr();
386387
auto sg_data = DenseI32ArrayAttr();
387-
auto inst_data = DenseI32ArrayAttr();
388388
auto order = DenseI32ArrayAttr();
389-
return $_get($_ctxt, sg_layout, sg_data, inst_data,
389+
return $_get($_ctxt, sg_layout, sg_data,
390+
DenseI32ArrayAttr::get($_ctxt, inst_data),
390391
DenseI32ArrayAttr::get($_ctxt, lane_layout),
391392
DenseI32ArrayAttr::get($_ctxt, lane_data), order);
392393
}]>,
393394
AttrBuilder<(ins "llvm::ArrayRef<int32_t>": $lane_layout,
394-
"llvm::ArrayRef<int32_t>": $lane_data,
395-
"llvm::ArrayRef<int32_t>": $order),
395+
"llvm::ArrayRef<int32_t>": $lane_data),
396396
[{
397-
return $_get($_ctxt,
398-
/*sg_layout =*/ nullptr,
399-
/*sg_data =*/ nullptr,
400-
/*inst_data =*/ nullptr,
397+
auto sg_layout = DenseI32ArrayAttr();
398+
auto sg_data = DenseI32ArrayAttr();
399+
auto inst_data = DenseI32ArrayAttr();
400+
auto order = DenseI32ArrayAttr();
401+
return $_get($_ctxt, sg_layout, sg_data, inst_data,
401402
DenseI32ArrayAttr::get($_ctxt, lane_layout),
402-
DenseI32ArrayAttr::get($_ctxt, lane_data),
403-
DenseI32ArrayAttr::get($_ctxt, order));
403+
DenseI32ArrayAttr::get($_ctxt, lane_data), order);
404404
}]>,
405405
AttrBuilder<(ins "DenseI32ArrayAttr": $lane_layout,
406406
"DenseI32ArrayAttr": $lane_data,

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h

Lines changed: 0 additions & 30 deletions
This file was deleted.

mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,12 @@ def XeGPUPropagateLayout : Pass<"xegpu-propagate-layout"> {
4343
let options = [Option<
4444
"printOnly", "print-analysis-only", "bool",
4545
/*default=*/"false",
46-
"Print the result of layout propagation analysis and exit.">];
46+
"Print the result of layout propagation analysis and exit.">,
47+
Option<
48+
"layoutKind", "layout-kind", "std::string",
49+
/*default=*/"\"lane\"",
50+
"Propagate a `sg` / `inst` / `lane` level of xegpu layouts.">
51+
];
4752
}
4853

4954
def XeGPUWgToSgDistribute : Pass<"xegpu-wg-to-sg-distribute"> {

mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
#include "mlir/Dialect/Index/IR/IndexOps.h"
1212
#include "mlir/Dialect/Utils/IndexingUtils.h"
1313
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
14-
#include "mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h"
14+
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
1515
#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
1616
#include "mlir/IR/Builders.h"
1717
#include "mlir/IR/DialectImplementation.h"
@@ -229,8 +229,10 @@ LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
229229
}
230230

231231
if (inst_data && lane_layout && inst_data.size() != lane_layout.size()) {
232-
return emitError()
233-
<< "expected inst_data and lane_layout to have the same rank";
232+
return emitError() << "expected inst_data and lane_layout to have the same "
233+
"rank, got inst_data "
234+
<< inst_data.size() << ", lane_layout "
235+
<< lane_layout.size();
234236
}
235237

236238
// sg_data is optional for Workgroup layout, but its presence requires
@@ -568,10 +570,10 @@ TensorDescType::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
568570

569571
// for gather and scatter ops, Low-precision types are packed in 32-bit units.
570572
unsigned bitWidth = elementType.getIntOrFloatBitWidth();
571-
int chunkAlignmentFactor =
572-
bitWidth < targetinfo::packedSizeInBitsForGatherScatter
573-
? targetinfo::packedSizeInBitsForGatherScatter / bitWidth
574-
: 1;
573+
constexpr int packingBitSizeGatherScatter{32};
574+
int chunkAlignmentFactor = bitWidth < packingBitSizeGatherScatter
575+
? packingBitSizeGatherScatter / bitWidth
576+
: 1;
575577
auto scatterAttr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(encoding);
576578
if (scatterAttr) {
577579
int64_t chunkSize = scatterAttr.getChunkSizeAsInt();

0 commit comments

Comments
 (0)