Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmake/imex-version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
25123cc3692fdbcf837510f39de47ff353d482fc
0a6d2901990183ecdbab9240dbd8be92036d9c20
2 changes: 1 addition & 1 deletion cmake/imex.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ if (NOT DEFINED IMEX_INCLUDES)

# TODO: Change to main https://github.com/intel/mlir-extensions when all the
# required functionality is merged.
set(IMEX_URL https://github.com/intel/mlir-extensions)
set(IMEX_URL https://github.com/dchigarev/mlir-extensions)
gc_fetch_content(imex "${IMEX_HASH}" "${IMEX_URL}"
SET IMEX_CHECK_LLVM_VERSION=ON IMEX_ENABLE_L0_RUNTIME=0
)
Expand Down
2 changes: 1 addition & 1 deletion cmake/llvm-version-imex.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
3191587666aa3d1e53966bc8876614c7197fac4f
add6b2f35f2bcf1f59a2ab2d5b3dab124fe0895a
2 changes: 1 addition & 1 deletion cmake/llvm-version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
f6a756f35a4d0719a96b4e214905369d565d87da
add6b2f35f2bcf1f59a2ab2d5b3dab124fe0895a
14 changes: 8 additions & 6 deletions lib/gc/Transforms/GPU/LinalgToXeGPU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -764,7 +764,9 @@ loadNdDescTiles(PatternRewriter &rewriter, Location loc, ValueRange loadTiles,
if (vnniConf) {
vecLoadType = getVnniVector(tileType.getShape(), tileType.getElementType(),
*vnniConf);
packedAttr = mlir::UnitAttr::get(rewriter.getContext());
if (!transpose_bit) {
packedAttr = mlir::UnitAttr::get(rewriter.getContext());
}
Comment on lines -767 to +769
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Specifying both transpose_bit and packed in xegpu.load_nd is no longer supported. If transpose_bit is specified then the op already behaves as if packed was set.

}
SmallVector<Value> loadVec;
for (auto tile : loadTiles) {
Expand Down Expand Up @@ -1165,7 +1167,6 @@ static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp,
if (vnniFactor == -1)
return failure();

VnniConfig vnniConfA{.vnniFactor = vnniFactor, .vnniAxis = 1};
VnniConfig vnniConfB{.vnniFactor = vnniFactor, .vnniAxis = 0};
Comment on lines -1168 to 1170
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

xegpu.dpas no longer supports A arguments to be a 3D vector


// Load A sub-tiles.
Expand Down Expand Up @@ -1212,9 +1213,9 @@ static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp,
}

// Extract DPAS tiles from loaded sub-tiles.
TilesArray dpasVecA = extractVecSubTiles(rewriter, loc, loadVecA,
{dimM, kTile}, tileTypeA.getShape(),
{dpasTileM, dpasTileK}, vnniConfA);
TilesArray dpasVecA =
extractVecSubTiles(rewriter, loc, loadVecA, {dimM, kTile},
tileTypeA.getShape(), {dpasTileM, dpasTileK});
TilesArray dpasVecB = extractVecSubTiles(rewriter, loc, loadVecB,
{kTile, dimN}, tileTypeB.getShape(),
{dpasTileK, dpasTileN}, vnniConfB);
Expand Down Expand Up @@ -1629,7 +1630,8 @@ struct LinalgToXeGPU : public gc::impl::LinalgToXeGPUBase<LinalgToXeGPU> {
using LinalgToXeGPUBase::LinalgToXeGPUBase;

void runOnOperation() override {
LinalgToXeGPUOptions options{kTile, stages, dpasTile};
LinalgToXeGPUOptions options{
kTile, stages, SmallVector<int64_t>(dpasTile.begin(), dpasTile.end())};
Comment on lines -1632 to +1634
Copy link
Contributor Author

@dchigarev dchigarev Oct 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does the conversion explicitly

error: could not convert ‘dpasTile’ from ‘mlir::Pass::ListOption<long int>’ to ‘llvm::SmallVector<long int>’


// Run GEMM pattern first to allow fusion with its consumers.
RewritePatternSet gemmPatterns(&getContext());
Expand Down
4 changes: 2 additions & 2 deletions lib/gc/Transforms/Pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,10 @@ void populateTensorPasses(mlir::OpPassManager &pm) {
// scf + arith + math + vector + tensor + linalg.brgemm
void populateVectorPasses(mlir::OpPassManager &pm) {
// Do promotion for math / arith ops
pm.addNestedPass<func::FuncOp>(math::createMathLegalizeToF32());
pm.addNestedPass<func::FuncOp>(math::createMathExtendToSupportedTypes());
// sourceTypeStrs can be extended
arith::ArithEmulateUnsupportedFloatsOptions options;
std::array<std::string, 1> typeStr = {"bf16"};
SmallVector<std::string, 1> typeStr = {"bf16"};
options.sourceTypeStrs = typeStr;
options.targetTypeStr = "f32";
pm.addNestedPass<func::FuncOp>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ module {
// Tile B is already in the correct shape.
// CHECK: %[[vA_flat:.+]] = vector.shape_cast %[[vA]] : vector<32x16xf16> to vector<512xf16>
// CHECK: %[[vA_dpas_flat:.+]] = vector.extract_strided_slice{{.*}}: vector<512xf16> to vector<128xf16>
// CHECK: %[[vA_dpas:.+]] = vector.shape_cast %[[vA_dpas_flat]] : vector<128xf16> to vector<8x8x2xf16>
// CHECK: %[[vA_dpas:.+]] = vector.shape_cast %[[vA_dpas_flat]] : vector<128xf16> to vector<8x16xf16>
// CHECK-COUNT-3: vector.extract_strided_slice

// Perform DPAS computation.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ module {
// Tile B is already in the correct shape.
// CHECK: %[[vA_flat:.+]] = vector.shape_cast %[[vA]] : vector<32x16xf16> to vector<512xf16>
// CHECK: %[[vA_dpas_flat:.+]] = vector.extract_strided_slice{{.*}}: vector<512xf16> to vector<128xf16>
// CHECK: %[[vA_dpas:.+]] = vector.shape_cast %[[vA_dpas_flat]] : vector<128xf16> to vector<8x8x2xf16>
// CHECK: %[[vA_dpas:.+]] = vector.shape_cast %[[vA_dpas_flat]] : vector<128xf16> to vector<8x16xf16>
// CHECK-COUNT-3: vector.extract_strided_slice

// Perform DPAS computation.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ module {
// Tile B is already in the correct shape.
// CHECK: %[[vA_flat:.+]] = vector.shape_cast %[[vA]] : vector<32x16xf16> to vector<512xf16>
// CHECK: %[[vA_dpas_flat:.+]] = vector.extract_strided_slice{{.*}}: vector<512xf16> to vector<128xf16>
// CHECK: %[[vA_dpas:.+]] = vector.shape_cast %[[vA_dpas_flat]] : vector<128xf16> to vector<8x8x2xf16>
// CHECK: %[[vA_dpas:.+]] = vector.shape_cast %[[vA_dpas_flat]] : vector<128xf16> to vector<8x16xf16>
// CHECK-COUNT-3: vector.extract_strided_slice

// Perform DPAS computation.
Expand Down
2 changes: 1 addition & 1 deletion test/mlir/test/gc/Transforms/GPU/linalg-to-xegpu-dpas.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func.func @matmul(%arg0: memref<32x32xf16>, %arg1: memref<32x32xf16>, %arg2: mem
// Tile B is already in the correct shape.
// CHECK: %[[vA_flat:.+]] = vector.shape_cast %[[vA]] : vector<32x16xf16> to vector<512xf16>
// CHECK: %[[vA_dpas_flat:.+]] = vector.extract_strided_slice{{.*}}: vector<512xf16> to vector<128xf16>
// CHECK: %[[vA_dpas:.+]] = vector.shape_cast %[[vA_dpas_flat]] : vector<128xf16> to vector<8x8x2xf16>
// CHECK: %[[vA_dpas:.+]] = vector.shape_cast %[[vA_dpas_flat]] : vector<128xf16> to vector<8x16xf16>
// CHECK-COUNT-3: vector.extract_strided_slice

// Perform DPAS computation.
Expand Down
2 changes: 1 addition & 1 deletion test/mlir/test/gc/Transforms/bf16Legalization.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: gc-opt %s --math-legalize-to-f32 --arith-emulate-unsupported-floats="source-types=bf16 target-type=f32" --canonicalize | FileCheck %s
// RUN: gc-opt %s --math-extend-to-supported-types --arith-emulate-unsupported-floats="source-types=bf16 target-type=f32" --canonicalize | FileCheck %s

// CHECK-LABEL: @sin
// CHECK-SAME: ([[ARG0:%.+]]: bf16)
Expand Down