Skip to content

Commit 21571cf

Browse files
authored
[HLSL][Matrix] Support row and column indexing modes for MatrixSubscriptExpr (#171564)
Fixes #167617. In DXC, HLSL supports different indexing modes via codegen for its equivalent of MatrixSubscriptExpr when the /Zpr and /Zpc flags are used; see https://godbolt.org/z/bz5Y5WG36. This change modifies EmitMatrixSubscriptExpr to consider the MatrixRowMajor/MatrixColMajor layout flags before generating an index. Similarly, it introduces `CreateRowMajorIndex` and `CreateColumnMajorIndex` in `MatrixBuilder.h` (private helpers that `CreateIndex` dispatches to) for use in `VisitMatrixSubscriptExpr`.
1 parent cb56a91 commit 21571cf

File tree

6 files changed

+198
-12
lines changed

6 files changed

+198
-12
lines changed

clang/lib/CodeGen/CGExpr.cpp

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2482,7 +2482,10 @@ RValue CodeGenFunction::EmitLoadOfLValue(LValue LV, SourceLocation Loc) {
24822482

24832483
for (unsigned Col = 0; Col < NumCols; ++Col) {
24842484
llvm::Value *ColIdx = llvm::ConstantInt::get(Row->getType(), Col);
2485-
llvm::Value *EltIndex = MB.CreateIndex(Row, ColIdx, NumRows);
2485+
bool IsMatrixRowMajor = getLangOpts().getDefaultMatrixMemoryLayout() ==
2486+
LangOptions::MatrixMemoryLayout::MatrixRowMajor;
2487+
llvm::Value *EltIndex =
2488+
MB.CreateIndex(Row, ColIdx, NumRows, NumCols, IsMatrixRowMajor);
24862489
llvm::Value *Elt = Builder.CreateExtractElement(MatrixVec, EltIndex);
24872490
llvm::Value *Lane = llvm::ConstantInt::get(Builder.getInt32Ty(), Col);
24882491
Result = Builder.CreateInsertElement(Result, Elt, Lane);
@@ -2733,7 +2736,10 @@ void CodeGenFunction::EmitStoreThroughLValue(RValue Src, LValue Dst,
27332736

27342737
for (unsigned Col = 0; Col < NumCols; ++Col) {
27352738
llvm::Value *ColIdx = llvm::ConstantInt::get(Row->getType(), Col);
2736-
llvm::Value *EltIndex = MB.CreateIndex(Row, ColIdx, NumRows);
2739+
bool IsMatrixRowMajor = getLangOpts().getDefaultMatrixMemoryLayout() ==
2740+
LangOptions::MatrixMemoryLayout::MatrixRowMajor;
2741+
llvm::Value *EltIndex =
2742+
MB.CreateIndex(Row, ColIdx, NumRows, NumCols, IsMatrixRowMajor);
27372743
llvm::Value *Lane = llvm::ConstantInt::get(Builder.getInt32Ty(), Col);
27382744
llvm::Value *NewElt = Builder.CreateExtractElement(RowVal, Lane);
27392745
MatrixVec = Builder.CreateInsertElement(MatrixVec, NewElt, EltIndex);
@@ -4976,12 +4982,15 @@ LValue CodeGenFunction::EmitMatrixSubscriptExpr(const MatrixSubscriptExpr *E) {
49764982
// Extend or truncate the index type to 32 or 64-bits if needed.
49774983
llvm::Value *RowIdx = EmitMatrixIndexExpr(E->getRowIdx());
49784984
llvm::Value *ColIdx = EmitMatrixIndexExpr(E->getColumnIdx());
4979-
4980-
llvm::Value *NumRows = Builder.getIntN(
4981-
RowIdx->getType()->getScalarSizeInBits(),
4982-
E->getBase()->getType()->castAs<ConstantMatrixType>()->getNumRows());
4985+
llvm::MatrixBuilder MB(Builder);
4986+
const auto *MatrixTy = E->getBase()->getType()->castAs<ConstantMatrixType>();
4987+
unsigned NumCols = MatrixTy->getNumColumns();
4988+
unsigned NumRows = MatrixTy->getNumRows();
4989+
bool IsMatrixRowMajor = getLangOpts().getDefaultMatrixMemoryLayout() ==
4990+
LangOptions::MatrixMemoryLayout::MatrixRowMajor;
49834991
llvm::Value *FinalIdx =
4984-
Builder.CreateAdd(Builder.CreateMul(ColIdx, NumRows), RowIdx);
4992+
MB.CreateIndex(RowIdx, ColIdx, NumRows, NumCols, IsMatrixRowMajor);
4993+
49854994
return LValue::MakeMatrixElt(
49864995
MaybeConvertMatrixAddress(Base.getAddress(), *this), FinalIdx,
49874996
E->getBase()->getType(), Base.getBaseInfo(), TBAAAccessInfo());

clang/lib/CodeGen/CGExprScalar.cpp

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2136,7 +2136,10 @@ Value *ScalarExprEmitter::VisitMatrixSingleSubscriptExpr(
21362136

21372137
for (unsigned Col = 0; Col != NumColumns; ++Col) {
21382138
Value *ColVal = llvm::ConstantInt::get(RowIdx->getType(), Col);
2139-
Value *EltIdx = MB.CreateIndex(RowIdx, ColVal, NumRows, "matrix_row_idx");
2139+
bool IsMatrixRowMajor = CGF.getLangOpts().getDefaultMatrixMemoryLayout() ==
2140+
LangOptions::MatrixMemoryLayout::MatrixRowMajor;
2141+
Value *EltIdx = MB.CreateIndex(RowIdx, ColVal, NumRows, NumColumns,
2142+
IsMatrixRowMajor, "matrix_row_idx");
21402143
Value *Elt =
21412144
Builder.CreateExtractElement(FlatMatrix, EltIdx, "matrix_elem");
21422145
Value *Lane = llvm::ConstantInt::get(Builder.getInt32Ty(), Col);
@@ -2155,9 +2158,15 @@ Value *ScalarExprEmitter::VisitMatrixSubscriptExpr(MatrixSubscriptExpr *E) {
21552158
Value *ColumnIdx = CGF.EmitMatrixIndexExpr(E->getColumnIdx());
21562159

21572160
const auto *MatrixTy = E->getBase()->getType()->castAs<ConstantMatrixType>();
2158-
unsigned NumRows = MatrixTy->getNumRows();
21592161
llvm::MatrixBuilder MB(Builder);
2160-
Value *Idx = MB.CreateIndex(RowIdx, ColumnIdx, NumRows);
2162+
2163+
Value *Idx;
2164+
unsigned NumCols = MatrixTy->getNumColumns();
2165+
unsigned NumRows = MatrixTy->getNumRows();
2166+
bool IsMatrixRowMajor = CGF.getLangOpts().getDefaultMatrixMemoryLayout() ==
2167+
LangOptions::MatrixMemoryLayout::MatrixRowMajor;
2168+
Idx = MB.CreateIndex(RowIdx, ColumnIdx, NumRows, NumCols, IsMatrixRowMajor);
2169+
21612170
if (CGF.CGM.getCodeGenOpts().OptimizationLevel > 0)
21622171
MB.CreateIndexAssumption(Idx, MatrixTy->getNumElementsFlattened());
21632172

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
// RUN: %clang_cc1 -fenable-matrix -fmatrix-memory-layout=row-major -triple x86_64-apple-darwin %s -emit-llvm -disable-llvm-passes -o - | FileCheck %s --check-prefixes=CHECK,ROW-CHECK
2+
// RUN: %clang_cc1 -fenable-matrix -fmatrix-memory-layout=column-major -triple x86_64-apple-darwin %s -emit-llvm -disable-llvm-passes -o - | FileCheck %s --check-prefixes=CHECK,COL-CHECK
3+
// RUN: %clang_cc1 -fenable-matrix -triple x86_64-apple-darwin %s -emit-llvm -disable-llvm-passes -o - | FileCheck %s --check-prefixes=CHECK,COL-CHECK
4+
5+
typedef float fx2x3_t __attribute__((matrix_type(2, 3)));
6+
float Out[6];
7+
8+
fx2x3_t gM;
9+
10+
// Reads one element of the 2x3 matrix M: index is split into row = index / 3
// and col = index % 3, and M[row][col] is written to Out[index]. The
// expectations below want both indices widened to i64, then col * 2 + row
// under the column-major prefix and row * 3 + col under the row-major prefix,
// feeding an extractelement on the <6 x float> load of M.
void binaryOpMatrixSubscriptExpr(int index, fx2x3_t M) {
11+
// CHECK-LABEL: binaryOpMatrixSubscriptExpr
12+
// CHECK: %row = alloca i32, align 4
13+
// CHECK: %col = alloca i32, align 4
14+
// CHECK: [[row_load:%.*]] = load i32, ptr %row, align 4
15+
// CHECK-NEXT: [[row_load_zext:%.*]] = zext i32 [[row_load]] to i64
16+
// CHECK-NEXT: [[col_load:%.*]] = load i32, ptr %col, align 4
17+
// CHECK-NEXT: [[col_load_zext:%.*]] = zext i32 [[col_load]] to i64
18+
// COL-CHECK-NEXT: [[col_offset:%.*]] = mul i64 [[col_load_zext]], 2
19+
// COL-CHECK-NEXT: [[col_major_index:%.*]] = add i64 [[col_offset]], [[row_load_zext]]
20+
// ROW-CHECK-NEXT: [[row_offset:%.*]] = mul i64 [[row_load_zext]], 3
21+
// ROW-CHECK-NEXT: [[row_major_index:%.*]] = add i64 [[row_offset]], [[col_load_zext]]
22+
// CHECK-NEXT: [[matrix_as_vec:%.*]] = load <6 x float>, ptr %M.addr, align 4
23+
// COL-CHECK-NEXT: %matrixext = extractelement <6 x float> [[matrix_as_vec]], i64 [[col_major_index]]
24+
// ROW-CHECK-NEXT: %matrixext = extractelement <6 x float> [[matrix_as_vec]], i64 [[row_major_index]]
25+
// row and col are declared unsigned, which pairs with the zext expected above;
// their names also appear literally in the %row / %col alloca expectations.
const unsigned int COLS = 3;
26+
unsigned int row = index / COLS;
27+
unsigned int col = index % COLS;
28+
Out[index] = M[row][col];
29+
}
30+
31+
// Returns M[row][col] for signed int indices. The expectations below want each
// index sign-extended to i64 before the layout-dependent flat index is formed
// (col * 2 + row or row * 3 + col), and the extracted element returned as-is.
float returnMatrixSubscriptExpr(int row, int col, fx2x3_t M) {
32+
// CHECK-LABEL: returnMatrixSubscriptExpr
33+
// CHECK: [[row_load:%.*]] = load i32, ptr [[row_ptr:%.*]], align 4
34+
// CHECK-NEXT: [[row_load_sext:%.*]] = sext i32 [[row_load]] to i64
35+
// CHECK-NEXT: [[col_load:%.*]] = load i32, ptr [[col_ptr:%.*]], align 4
36+
// CHECK-NEXT: [[col_load_sext:%.*]] = sext i32 [[col_load]] to i64
37+
// COL-CHECK-NEXT: [[col_offset:%.*]] = mul i64 [[col_load_sext]], 2
38+
// COL-CHECK-NEXT: [[col_major_index:%.*]] = add i64 [[col_offset]], [[row_load_sext]]
39+
// ROW-CHECK-NEXT: [[row_offset:%.*]] = mul i64 [[row_load_sext]], 3
40+
// ROW-CHECK-NEXT: [[row_major_index:%.*]] = add i64 [[row_offset]], [[col_load_sext]]
41+
// CHECK-NEXT: [[matrix_as_vec:%.*]] = load <6 x float>, ptr %M.addr, align 4
42+
// COL-CHECK-NEXT: [[matrix_after_extract:%.*]] = extractelement <6 x float> [[matrix_as_vec]], i64 [[col_major_index]]
43+
// ROW-CHECK-NEXT: [[matrix_after_extract:%.*]] = extractelement <6 x float> [[matrix_as_vec]], i64 [[row_major_index]]
44+
// CHECK-NEXT: ret float [[matrix_after_extract]]
45+
return M[row][col];
46+
}
47+
48+
// Stores value into the global matrix gM at [row][col]. The expectations below
// want the layout-dependent flat index computed first, then a load of gM, an
// insertelement at that index, and a store of the whole <6 x float> back to gM.
void storeAtMatrixSubscriptExpr(int row, int col, float value) {
49+
// CHECK-LABEL: storeAtMatrixSubscriptExpr
50+
// CHECK: [[value_load:%.*]] = load float, ptr [[value_ptr:%.*]], align 4
51+
// ROW-CHECK: [[row_offset:%.*]] = mul i64 [[row_load:%.*]], 3
52+
// ROW-CHECK-NEXT: [[row_major_index:%.*]] = add i64 [[row_offset]], [[col_load:%.*]]
53+
// COL-CHECK: [[col_offset:%.*]] = mul i64 [[col_load:%.*]], 2
54+
// COL-CHECK-NEXT: [[col_major_index:%.*]] = add i64 [[col_offset]], [[row_load:%.*]]
55+
// CHECK-NEXT: [[matrix_as_vec:%.*]] = load <6 x float>, ptr @gM, align 4
56+
// ROW-CHECK-NEXT: [[matrix_after_insert:%.*]] = insertelement <6 x float> [[matrix_as_vec]], float [[value_load]], i64 [[row_major_index]]
57+
// COL-CHECK-NEXT: [[matrix_after_insert:%.*]] = insertelement <6 x float> [[matrix_as_vec]], float [[value_load]], i64 [[col_major_index]]
58+
// CHECK-NEXT: store <6 x float> [[matrix_after_insert]], ptr @gM, align 4
59+
gM[row][col] = value;
60+
}
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
// RUN: %clang_cc1 -fenable-matrix -fmatrix-memory-layout=column-major -triple x86_64-apple-darwin %s -emit-llvm -disable-llvm-passes -o - -std=c++11 | FileCheck %s
2+
// RUN: %clang_cc1 -fenable-matrix -triple x86_64-apple-darwin %s -emit-llvm -disable-llvm-passes -o - -std=c++11 | FileCheck %s
3+
4+
typedef float fx2x3_t __attribute__((matrix_type(2, 3)));
5+
float Out[6];
6+
7+
// Reads M[row][col] with row = index / 3 and col = index % 3 and writes it to
// Out[index]. Both RUN lines of this file share one set of expectations, which
// want the column-major flat index col * 2 + row on zero-extended i64 indices.
void binaryOpMatrixSubscriptExpr(int index, fx2x3_t M) {
8+
// CHECK-LABEL: binaryOpMatrixSubscriptExpr
9+
// CHECK: %row = alloca i32, align 4
10+
// CHECK: %col = alloca i32, align 4
11+
// CHECK: [[row_load:%.*]] = load i32, ptr %row, align 4
12+
// CHECK-NEXT: [[row_load_zext:%.*]] = zext i32 [[row_load]] to i64
13+
// CHECK-NEXT: [[col_load:%.*]] = load i32, ptr %col, align 4
14+
// CHECK-NEXT: [[col_load_zext:%.*]] = zext i32 [[col_load]] to i64
15+
// CHECK-NEXT: [[col_offset:%.*]] = mul i64 [[col_load_zext]], 2
16+
// CHECK-NEXT: [[col_major_index:%.*]] = add i64 [[col_offset]], [[row_load_zext]]
17+
// CHECK-NEXT: [[matrix_as_vec:%.*]] = load <6 x float>, ptr %M.addr, align 4
18+
// CHECK-NEXT: %matrixext = extractelement <6 x float> [[matrix_as_vec]], i64 [[col_major_index]]
19+
// row and col are declared unsigned, which pairs with the zext expected above;
// their names also appear literally in the %row / %col alloca expectations.
const unsigned int COLS = 3;
20+
unsigned int row = index / COLS;
21+
unsigned int col = index % COLS;
22+
Out[index] = M[row][col];
23+
}
24+
25+
// Returns M[row][col] for signed int indices. The expectations below want each
// index sign-extended to i64 and combined as the column-major flat index
// col * 2 + row before the extractelement on the <6 x float> load of M.
float returnMatrixSubscriptExpr(int row, int col, fx2x3_t M) {
26+
// CHECK-LABEL: returnMatrixSubscriptExpr
27+
// CHECK: [[row_load:%.*]] = load i32, ptr %row.addr, align 4
28+
// CHECK-NEXT: [[row_load_sext:%.*]] = sext i32 [[row_load]] to i64
29+
// CHECK-NEXT: [[col_load:%.*]] = load i32, ptr %col.addr, align 4
30+
// CHECK-NEXT: [[col_load_sext:%.*]] = sext i32 [[col_load]] to i64
31+
// CHECK-NEXT: [[col_offset:%.*]] = mul i64 [[col_load_sext]], 2
32+
// CHECK-NEXT: [[col_major_index:%.*]] = add i64 [[col_offset]], [[row_load_sext]]
33+
// CHECK-NEXT: [[matrix_as_vec:%.*]] = load <6 x float>, ptr %M.addr, align 4
34+
// CHECK-NEXT: %matrixext = extractelement <6 x float> [[matrix_as_vec]], i64 [[col_major_index]]
35+
return M[row][col];
36+
}
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple dxil-pc-shadermodel6.3-library %s -fnative-half-type -emit-llvm -disable-llvm-passes -fmatrix-memory-layout=row-major -o - | FileCheck %s --check-prefixes=CHECK,ROW-CHECK
2+
// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple dxil-pc-shadermodel6.3-library %s -fnative-half-type -emit-llvm -disable-llvm-passes -fmatrix-memory-layout=column-major -o - | FileCheck %s --check-prefixes=CHECK,COL-CHECK
3+
// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple dxil-pc-shadermodel6.3-library %s -fnative-half-type -emit-llvm -disable-llvm-passes -o - | FileCheck %s --check-prefixes=CHECK,COL-CHECK
4+
5+
RWBuffer<int> Out : register(u1);
6+
half2x3 gM;
7+
8+
9+
// HLSL variant: reads M[row][col] from a half2x3 with row = index / 3 and
// col = index % 3, and writes it to Out[index]. The expectations below want
// the flat index built directly on the loaded i32 indices (no widening step):
// row * 3 + col under the row-major prefix, col * 2 + row under the
// column-major prefix, feeding an extractelement on the <6 x half> load of M.
void binaryOpMatrixSubscriptExpr(int index, half2x3 M) {
10+
// CHECK-LABEL: binaryOpMatrixSubscriptExpr
11+
// CHECK: %row = alloca i32, align 4
12+
// CHECK: %col = alloca i32, align 4
13+
// CHECK: [[row_load:%.*]] = load i32, ptr %row, align 4
14+
// CHECK-NEXT: [[col_load:%.*]] = load i32, ptr %col, align 4
15+
// ROW-CHECK-NEXT: [[row_offset:%.*]] = mul i32 [[row_load]], 3
16+
// ROW-CHECK-NEXT: [[row_major_index:%.*]] = add i32 [[row_offset]], [[col_load]]
17+
// COL-CHECK-NEXT: [[col_offset:%.*]] = mul i32 [[col_load]], 2
18+
// COL-CHECK-NEXT: [[col_major_index:%.*]] = add i32 [[col_offset]], [[row_load]]
19+
// CHECK-NEXT: [[matrix_as_vec:%.*]] = load <6 x half>, ptr %M.addr, align 2
20+
// ROW-CHECK-NEXT: %matrixext = extractelement <6 x half> [[matrix_as_vec]], i32 [[row_major_index]]
21+
// COL-CHECK-NEXT: %matrixext = extractelement <6 x half> [[matrix_as_vec]], i32 [[col_major_index]]
22+
// The local names row and col appear literally in the %row / %col alloca
// expectations above.
const uint COLS = 3;
23+
uint row = index / COLS;
24+
uint col = index % COLS;
25+
Out[index] = M[row][col];
26+
}
27+
28+
// HLSL variant: returns M[row][col] from a half2x3. The expectations below
// want an i32 flat index (row * 3 + col for row-major, col * 2 + row for
// column-major) followed by an extractelement on the <6 x half> load of M.
half returnMatrixSubscriptExpr(int row, int col, half2x3 M) {
29+
// CHECK-LABEL: returnMatrixSubscriptExpr
30+
// ROW-CHECK: [[row_offset:%.*]] = mul i32 [[row_load:%.*]], 3
31+
// ROW-CHECK-NEXT: [[row_major_index:%.*]] = add i32 [[row_offset]], [[col_load:%.*]]
32+
// COL-CHECK: [[col_offset:%.*]] = mul i32 [[col_load:%.*]], 2
33+
// COL-CHECK-NEXT: [[col_major_index:%.*]] = add i32 [[col_offset]], [[row_load:%.*]]
34+
// CHECK-NEXT: [[matrix_as_vec:%.*]] = load <6 x half>, ptr %M.addr, align 2
35+
// ROW-CHECK-NEXT: %matrixext = extractelement <6 x half> [[matrix_as_vec]], i32 [[row_major_index]]
36+
// COL-CHECK-NEXT: %matrixext = extractelement <6 x half> [[matrix_as_vec]], i32 [[col_major_index]]
37+
return M[row][col];
38+
}
39+
40+
// HLSL variant: stores value into the global half2x3 gM at [row][col]. The
// expectations below want the i32 flat index computed first, then a load of
// gM (referenced through addrspace(2)), an insertelement at that index, and a
// store of the whole <6 x half> back to gM.
void storeAtMatrixSubscriptExpr(int row, int col, half value) {
41+
// CHECK-LABEL: storeAtMatrixSubscriptExpr
42+
// CHECK: [[value_load:%.*]] = load half, ptr [[value_ptr:%.*]], align 2
43+
// ROW-CHECK: [[row_offset:%.*]] = mul i32 [[row_load:%.*]], 3
44+
// ROW-CHECK-NEXT: [[row_major_index:%.*]] = add i32 [[row_offset]], [[col_load:%.*]]
45+
// COL-CHECK: [[col_offset:%.*]] = mul i32 [[col_load:%.*]], 2
46+
// COL-CHECK-NEXT: [[col_major_index:%.*]] = add i32 [[col_offset]], [[row_load:%.*]]
47+
// CHECK-NEXT: [[matrix_as_vec:%.*]] = load <6 x half>, ptr addrspace(2) @gM, align 2
48+
// ROW-CHECK-NEXT: [[matrix_after_insert:%.*]] = insertelement <6 x half> [[matrix_as_vec]], half [[value_load]], i32 [[row_major_index]]
49+
// COL-CHECK-NEXT: [[matrix_after_insert:%.*]] = insertelement <6 x half> [[matrix_as_vec]], half [[value_load]], i32 [[col_major_index]]
50+
// CHECK-NEXT: store <6 x half> [[matrix_after_insert]], ptr addrspace(2) @gM, align 2
51+
gM[row][col] = value;
52+
}

llvm/include/llvm/IR/MatrixBuilder.h

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -238,19 +238,39 @@ class MatrixBuilder {
238238
else
239239
B.CreateAssumption(Cmp);
240240
}
241-
242241
/// Compute the index to access the element at (\p RowIdx, \p ColumnIdx) from
243-
/// a matrix with \p NumRows embedded in a vector.
242+
/// a matrix embedded in a vector: the stride is \p NumCols when
243+
/// \p IsMatrixRowMajor is true and \p NumRows otherwise.
244244
Value *CreateIndex(Value *RowIdx, Value *ColumnIdx, unsigned NumRows,
                   unsigned NumCols, bool IsMatrixRowMajor = false,
                   Twine const &Name = "") {
  // Bring both indices to the wider of their two bit widths so the multiply
  // and add performed by the helpers operate on a single integer type.
  Type *RowTy = RowIdx->getType();
  unsigned MaxWidth = std::max(RowTy->getScalarSizeInBits(),
                               ColumnIdx->getType()->getScalarSizeInBits());
  Type *IntTy = IntegerType::get(RowTy->getContext(), MaxWidth);
  RowIdx = B.CreateZExt(RowIdx, IntTy);
  ColumnIdx = B.CreateZExt(ColumnIdx, IntTy);

  // Column-major (the default) strides by the row count; row-major strides by
  // the column count. The stride constant uses the same width as the indices.
  if (!IsMatrixRowMajor)
    return CreateColumnMajorIndex(RowIdx, ColumnIdx,
                                  B.getIntN(MaxWidth, NumRows), Name);
  return CreateRowMajorIndex(RowIdx, ColumnIdx, B.getIntN(MaxWidth, NumCols),
                             Name);
}
259+
260+
private:
261+
/// Compute the index to access the element at (\p RowIdx, \p ColumnIdx) from
262+
/// a matrix with \p NumRows embedded in a vector.
263+
Value *CreateColumnMajorIndex(Value *RowIdx, Value *ColumnIdx,
                              Value *NumRowsV, Twine const &Name) {
  // Column-major flat index: ColumnIdx * NumRows + RowIdx.
  Value *ColumnOffset = B.CreateMul(ColumnIdx, NumRowsV);
  // Forward \p Name to the final add so the caller-supplied name lands on the
  // returned index value instead of being silently dropped.
  return B.CreateAdd(ColumnOffset, RowIdx, Name);
}
267+
268+
/// Compute the index to access the element at (\p RowIdx, \p ColumnIdx) from
269+
/// a matrix with \p NumCols embedded in a vector.
270+
Value *CreateRowMajorIndex(Value *RowIdx, Value *ColumnIdx, Value *NumColsV,
                           Twine const &Name) {
  // Row-major flat index: RowIdx * NumCols + ColumnIdx.
  Value *RowOffset = B.CreateMul(RowIdx, NumColsV);
  // Forward \p Name to the final add so the caller-supplied name lands on the
  // returned index value instead of being silently dropped.
  return B.CreateAdd(RowOffset, ColumnIdx, Name);
}
254274
};
255275

256276
} // end namespace llvm

0 commit comments

Comments
 (0)