Skip to content

Commit c1d3475

Browse files
KanclerzPiotrigcbot
authored andcommitted
Support aggregate with bools promotion in functions
This PR enables support of structs and arrays with bools in function arguments and return types.
1 parent cd5c825 commit c1d3475

File tree

3 files changed

+158
-10
lines changed

3 files changed

+158
-10
lines changed

IGC/AdaptorOCL/preprocess_spvir/PromoteBools.cpp

Lines changed: 64 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,7 @@ SPDX-License-Identifier: MIT
99
#include "PromoteBools.h"
1010
#include "Compiler/IGCPassSupport.h"
1111
#include "common/LLVMWarningsPush.hpp"
12-
#include "llvmWrapper/IR/DerivedTypes.h"
13-
#include "llvmWrapper/IR/Instructions.h"
1412
#include "llvmWrapper/IR/Type.h"
15-
#include "llvmWrapper/IR/Function.h"
1613
#include "llvmWrapper/Support/Alignment.h"
1714
#include "llvmWrapper/Transforms/Utils/Cloning.h"
1815
#include <llvm/IR/Module.h>
@@ -102,7 +99,17 @@ bool PromoteBools::runOnModule(Module &module) {
10299
return changed;
103100
}
104101

102+
Value *PromoteBools::convertI8ToI1(Value *value, Instruction *insertBefore) {
103+
IRBuilder<> builder(insertBefore);
104+
return convertI8ToI1(value, builder);
105+
}
106+
105107
Value *PromoteBools::convertI1ToI8(Value *value, Instruction *insertBefore) {
108+
IRBuilder<> builder(insertBefore);
109+
return convertI1ToI8(value, builder);
110+
}
111+
112+
Value *PromoteBools::convertI1ToI8(Value *value, IRBuilder<> &builder) {
106113
if (!isIntegerTy(value, 1)) {
107114
return value;
108115
}
@@ -112,15 +119,14 @@ Value *PromoteBools::convertI1ToI8(Value *value, Instruction *insertBefore) {
112119
return trunc->getOperand(0);
113120
}
114121

115-
IRBuilder<> builder(insertBefore);
116122
auto zext = builder.CreateZExt(value, getOrCreatePromotedType(value->getType()));
117123
if (isa<Instruction>(zext) && isa<Instruction>(value)) {
118124
dyn_cast<Instruction>(zext)->setDebugLoc(dyn_cast<Instruction>(value)->getDebugLoc());
119125
}
120126
return zext;
121127
}
122128

123-
Value *PromoteBools::convertI8ToI1(Value *value, Instruction *insertBefore) {
129+
Value *PromoteBools::convertI8ToI1(Value *value, IRBuilder<> &builder) {
124130
if (!isIntegerTy(value, 8)) {
125131
return value;
126132
}
@@ -130,26 +136,75 @@ Value *PromoteBools::convertI8ToI1(Value *value, Instruction *insertBefore) {
130136
return zext->getOperand(0);
131137
}
132138

133-
IRBuilder<> builder(insertBefore);
134139
auto trunc = builder.CreateTrunc(value, createDemotedType(value->getType()));
135140
if (isa<Instruction>(trunc) && isa<Instruction>(value)) {
136141
dyn_cast<Instruction>(trunc)->setDebugLoc(dyn_cast<Instruction>(value)->getDebugLoc());
137142
}
138143
return trunc;
139144
}
140145

146+
Value *PromoteBools::castAggregate(Value *value, Type *desiredType, IRBuilder<> &builder) {
147+
148+
if (auto *srcST = dyn_cast<StructType>(value->getType())) {
149+
auto *dstST = dyn_cast<StructType>(desiredType);
150+
151+
IGC_ASSERT(srcST && dstST);
152+
IGC_ASSERT(srcST->getNumElements() == dstST->getNumElements());
153+
154+
Value *Accum = PoisonValue::get(dstST);
155+
156+
for (unsigned i = 0; i < srcST->getNumElements(); i++) {
157+
auto *dstElType = dstST->getElementType(i);
158+
auto *srcElType = srcST->getElementType(i);
159+
160+
Value *ExtVal = builder.CreateExtractValue(value, i);
161+
Value *ExtValOrCast = (srcElType == dstElType) ? ExtVal : castTo(ExtVal, dstElType, builder);
162+
Accum = builder.CreateInsertValue(Accum, ExtValOrCast, i);
163+
}
164+
return Accum;
165+
}
166+
167+
if (auto *srcAT = dyn_cast<ArrayType>(value->getType())) {
168+
auto *dstAT = dyn_cast<ArrayType>(desiredType);
169+
170+
IGC_ASSERT(srcAT && dstAT);
171+
IGC_ASSERT(srcAT->getNumElements() == dstAT->getNumElements());
172+
173+
auto *dstElType = dstAT->getElementType();
174+
auto *srcElType = srcAT->getElementType();
175+
Value *Accum = PoisonValue::get(dstAT);
176+
177+
for (unsigned i = 0; i < srcAT->getNumElements(); i++) {
178+
Value *ExtVal = builder.CreateExtractValue(value, i);
179+
Value *ExtValOrCast = (srcElType == dstElType) ? ExtVal : castTo(ExtVal, dstElType, builder);
180+
Accum = builder.CreateInsertValue(Accum, ExtValOrCast, i);
181+
}
182+
183+
return Accum;
184+
}
185+
186+
return nullptr;
187+
}
188+
141189
Value *PromoteBools::castTo(Value *value, Type *desiredType, Instruction *insertBefore) {
190+
IRBuilder<> builder(insertBefore);
191+
return castTo(value, desiredType, builder);
192+
}
193+
Value *PromoteBools::castTo(Value *value, Type *desiredType, IRBuilder<> &builder) {
194+
142195
if (value->getType() == desiredType) {
143196
return value;
144197
}
145198

199+
if (desiredType->isAggregateType())
200+
return castAggregate(value, desiredType, builder);
201+
146202
if (isIntegerTy(value, 8) && desiredType->isIntegerTy(1)) {
147-
return convertI8ToI1(value, insertBefore);
203+
return convertI8ToI1(value, builder);
148204
} else if (isIntegerTy(value, 1) && desiredType->isIntegerTy(8)) {
149-
return convertI1ToI8(value, insertBefore);
205+
return convertI1ToI8(value, builder);
150206
}
151207

152-
IRBuilder<> builder(insertBefore);
153208
return builder.CreateBitCast(value, desiredType);
154209
}
155210

IGC/AdaptorOCL/preprocess_spvir/PromoteBools.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,14 @@ class PromoteBools : public llvm::ModulePass, public llvm::InstVisitor<PromoteBo
4545
private:
4646
bool changed;
4747

48+
llvm::Value *convertI1ToI8(llvm::Value *argument, llvm::IRBuilder<>& builder);
4849
llvm::Value *convertI1ToI8(llvm::Value *argument, llvm::Instruction *insertBefore);
50+
llvm::Value *convertI8ToI1(llvm::Value *argument, llvm::IRBuilder<>& builder);
4951
llvm::Value *convertI8ToI1(llvm::Value *argument, llvm::Instruction *insertBefore);
52+
llvm::Value *castTo(llvm::Value *value, llvm::Type *desiredType, llvm::IRBuilder<>& builder);
5053
llvm::Value *castTo(llvm::Value *value, llvm::Type *desiredType, llvm::Instruction *insertBefore);
54+
llvm::Value *castAggregate(llvm::Value *value, llvm::Type *desiredType, llvm::IRBuilder<>& builder);
55+
5156
void cleanUp(llvm::Module &module);
5257

5358
// Checking if type needs promotion

IGC/Compiler/tests/PromoteBools/functions-and-calls.ll

Lines changed: 89 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,15 @@
1111

1212
%struct_without_bools = type { i8 }
1313
%struct = type { [4 x <8 x i1*>], [4 x <8 x i1>*]* }
14+
%inner = type { i1 }
15+
%struct2 = type { i32, i1, %inner }
1416

1517
; CHECK: %struct_without_bools = type { i8 }
1618
; CHECK: %struct = type { [4 x <8 x i8*>], [4 x <8 x i8>*]* }
17-
19+
; CHECK: %struct2 = type { i32, i8, %inner }
20+
; CHECK: %inner = type { i8 }
21+
; CHECK: %struct2.unpromoted = type { i32, i1, %inner.unpromoted }
22+
; CHECK: %inner.unpromoted = type { i1 }
1823

1924
define spir_func void @fun_struct_without_bools(i1 %input1, %struct_without_bools %input2) {
2025
ret void
@@ -30,6 +35,7 @@ define spir_func i1 @callee0(i1 %input) {
3035
; CHECK: define spir_func i8 @callee0(i8 %input)
3136
; CHECK-NEXT: ret i8 %input
3237

38+
3339
define spir_func i1 @callee1(%struct addrspace(1)* %input_struct)
3440
{
3541
%1 = load %struct, %struct addrspace(1)* %input_struct
@@ -40,11 +46,90 @@ define spir_func i1 @callee1(%struct addrspace(1)* %input_struct)
4046
; CHECK-NEXT: %1 = load %struct, %struct addrspace(1)* %input_struct
4147
; CHECK-NEXT: ret i8 1
4248

49+
50+
define spir_func %struct2 @callee2(%struct2 %input_struct) {
51+
ret %struct2 %input_struct
52+
}
53+
54+
; CHECK: define spir_func %struct2 @callee2(%struct2 %input_struct) {
55+
; CHECK-NEXT: %1 = extractvalue %struct2 %input_struct, 0
56+
; CHECK-NEXT: %2 = insertvalue %struct2.unpromoted poison, i32 %1, 0
57+
; CHECK-NEXT: %3 = extractvalue %struct2 %input_struct, 1
58+
; CHECK-NEXT: %4 = trunc i8 %3 to i1
59+
; CHECK-NEXT: %5 = insertvalue %struct2.unpromoted %2, i1 %4, 1
60+
; CHECK-NEXT: %6 = extractvalue %struct2 %input_struct, 2
61+
; CHECK-NEXT: %7 = extractvalue %inner %6, 0
62+
; CHECK-NEXT: %8 = trunc i8 %7 to i1
63+
; CHECK-NEXT: %9 = insertvalue %inner.unpromoted poison, i1 %8, 0
64+
; CHECK-NEXT: %10 = insertvalue %struct2.unpromoted %5, %inner.unpromoted %9, 2
65+
; CHECK-NEXT: %11 = extractvalue %struct2.unpromoted %10, 0
66+
; CHECK-NEXT: %12 = insertvalue %struct2 poison, i32 %11, 0
67+
; CHECK-NEXT: %13 = extractvalue %struct2.unpromoted %10, 1
68+
; CHECK-NEXT: %14 = zext i1 %13 to i8
69+
; CHECK-NEXT: %15 = insertvalue %struct2 %12, i8 %14, 1
70+
; CHECK-NEXT: %16 = extractvalue %struct2.unpromoted %10, 2
71+
; CHECK-NEXT: %17 = extractvalue %inner.unpromoted %16, 0
72+
; CHECK-NEXT: %18 = zext i1 %17 to i8
73+
; CHECK-NEXT: %19 = insertvalue %inner poison, i8 %18, 0
74+
; CHECK-NEXT: %20 = insertvalue %struct2 %15, %inner %19, 2
75+
; CHECK-NEXT: ret %struct2 %20
76+
77+
78+
define spir_func [ 2 x i1 ] @callee3( [ 2 x i1 ] %input_array) {
79+
ret [ 2 x i1 ] %input_array
80+
}
81+
82+
; CHECK: define spir_func [2 x i8] @callee3([2 x i8] %input_array) {
83+
; CHECK-NEXT: %1 = extractvalue [2 x i8] %input_array, 0
84+
; CHECK-NEXT: %2 = trunc i8 %1 to i1
85+
; CHECK-NEXT: %3 = insertvalue [2 x i1] poison, i1 %2, 0
86+
; CHECK-NEXT: %4 = extractvalue [2 x i8] %input_array, 1
87+
; CHECK-NEXT: %5 = trunc i8 %4 to i1
88+
; CHECK-NEXT: %6 = insertvalue [2 x i1] %3, i1 %5, 1
89+
; CHECK-NEXT: %7 = extractvalue [2 x i1] %6, 0
90+
; CHECK-NEXT: %8 = zext i1 %7 to i8
91+
; CHECK-NEXT: %9 = insertvalue [2 x i8] poison, i8 %8, 0
92+
; CHECK-NEXT: %10 = extractvalue [2 x i1] %6, 1
93+
; CHECK-NEXT: %11 = zext i1 %10 to i8
94+
; CHECK-NEXT: %12 = insertvalue [2 x i8] %9, i8 %11, 1
95+
; CHECK-NEXT: ret [2 x i8] %12
96+
97+
98+
define spir_func [ 2 x %inner ] @calle4( [ 2 x %inner ] %input_array) {
99+
ret [ 2 x %inner ] %input_array
100+
}
101+
102+
; CHECK: define spir_func [2 x %inner] @calle4([2 x %inner] %input_array) {
103+
; CHECK-NEXT: %1 = extractvalue [2 x %inner] %input_array, 0
104+
; CHECK-NEXT: %2 = extractvalue %inner %1, 0
105+
; CHECK-NEXT: %3 = trunc i8 %2 to i1
106+
; CHECK-NEXT: %4 = insertvalue %inner.unpromoted poison, i1 %3, 0
107+
; CHECK-NEXT: %5 = insertvalue [2 x %inner.unpromoted] poison, %inner.unpromoted %4, 0
108+
; CHECK-NEXT: %6 = extractvalue [2 x %inner] %input_array, 1
109+
; CHECK-NEXT: %7 = extractvalue %inner %6, 0
110+
; CHECK-NEXT: %8 = trunc i8 %7 to i1
111+
; CHECK-NEXT: %9 = insertvalue %inner.unpromoted poison, i1 %8, 0
112+
; CHECK-NEXT: %10 = insertvalue [2 x %inner.unpromoted] %5, %inner.unpromoted %9, 1
113+
; CHECK-NEXT: %11 = extractvalue [2 x %inner.unpromoted] %10, 0
114+
; CHECK-NEXT: %12 = extractvalue %inner.unpromoted %11, 0
115+
; CHECK-NEXT: %13 = zext i1 %12 to i8
116+
; CHECK-NEXT: %14 = insertvalue %inner poison, i8 %13, 0
117+
; CHECK-NEXT: %15 = insertvalue [2 x %inner] poison, %inner %14, 0
118+
; CHECK-NEXT: %16 = extractvalue [2 x %inner.unpromoted] %10, 1
119+
; CHECK-NEXT: %17 = extractvalue %inner.unpromoted %16, 0
120+
; CHECK-NEXT: %18 = zext i1 %17 to i8
121+
; CHECK-NEXT: %19 = insertvalue %inner poison, i8 %18, 0
122+
; CHECK-NEXT: %20 = insertvalue [2 x %inner] %15, %inner %19, 1
123+
; CHECK-NEXT: ret [2 x %inner] %20
124+
43125
define spir_func void @caller(i1 %input, %struct addrspace(1)* %input_struct) {
44126
%1 = call i1 @callee0(i1 false)
45127
%2 = call i1 @callee0(i1 true)
46128
%3 = call i1 @callee0(i1 %input)
47129
%4 = call i1 @callee1(%struct addrspace(1)* %input_struct)
130+
%5 = call %struct2 @callee2( %struct2 { i32 5, i1 false, %inner { i1 true } })
131+
%6 = call [ 2 x i1 ] @callee3( [ 2 x i1 ] [ i1 true, i1 false])
132+
%7 = call [ 2 x %inner ] @calle4( [ 2 x %inner ] [ %inner { i1 true }, %inner { i1 false} ])
48133
ret void
49134
}
50135

@@ -53,4 +138,7 @@ define spir_func void @caller(i1 %input, %struct addrspace(1)* %input_struct) {
53138
; CHECK-NEXT: %2 = call i8 @callee0(i8 1)
54139
; CHECK-NEXT: %3 = call i8 @callee0(i8 %input)
55140
; CHECK-NEXT: %4 = call i8 @callee1(%struct addrspace(1)* %input_struct)
141+
; CHECK-NEXT: %5 = call %struct2 @callee2(%struct2 { i32 5, i8 0, %inner { i8 1 } })
142+
; CHECK-NEXT: %6 = call [2 x i8] @callee3([2 x i8] c"\01\00")
143+
; CHECK-NEXT: %7 = call [2 x %inner] @calle4([2 x %inner] [%inner { i8 1 }, %inner zeroinitializer])
56144
; CHECK-NEXT: ret void

0 commit comments

Comments
 (0)