Skip to content

Commit 495fde3

Browse files
Pangorawwsmoses
andauthored
tablegen: Add StaticSelect to select based on static condition (#2206)
* tablegen: Add StaticIf to select based on static condition * rename to StaticSelect and implement SelectIfActive and SelectIfComplex with it * define for llvm * put vector mode for LLVM back * basic use analysis * StaticSelect use analysis * fixup --------- Co-authored-by: William S. Moses <[email protected]>
1 parent ac3be7e commit 495fde3

File tree

3 files changed

+273
-206
lines changed

3 files changed

+273
-206
lines changed

enzyme/Enzyme/InstructionDerivatives.td

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,11 @@ class ConstantFP<string val> : Operation</*primal*/0, /*shadow*/0> {
1313
string value = val;
1414
}
1515

16+
class StaticSelect<string condition_> : Operation</*primal*/0, /*shadow*/0, /*custom*/0> {
17+
string condition = condition_;
18+
}
19+
20+
def SelectIfActive : StaticSelect<"!gutils->isConstantValue(imVal)">;
1621

1722
class Attribute<string name_> {
1823
string name = name_;
@@ -62,9 +67,6 @@ class Inst<string mnemonic> : Operation</*primal*/1, /*shadow*/0> {
6267
def TypeOf : Operation</*primal*/0, /*shadow*/0> {
6368
}
6469
def VectorSize : Operation</*primal*/0, /*shadow*/0> {
65-
}
66-
def SelectIfActive : Operation</*primal*/0, /*shadow*/0, /*custom*/1> {
67-
6870
}
6971

7072
// Define ops to rewrite.

enzyme/Enzyme/MLIR/Implementations/Common.td

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -86,10 +86,20 @@ def DiffeRet : DiffeRetIndex<[-1]>;
8686
def Shadow : Operation</*primal*/0, /*shadow*/1> {
8787
}
8888

89-
class GlobalExpr<bit uses_primal, bit uses_shadow, string val> : Operation<uses_primal, uses_shadow>{
89+
class GlobalExpr<bit uses_primal, bit uses_shadow, string val> : Operation<uses_primal, uses_shadow> {
9090
string value = val;
9191
}
9292

93+
// Class for a dag operator that generates either a or b
94+
// It can then be used with a two or three arguments.
95+
// The two arguments version is (StaticSelect a, b)
96+
// The three arguments version accepts a name as a first argument
97+
// which is then available in the condition as a `Value` under the
98+
// variable `imVal`.
99+
class StaticSelect<string condition_> : Operation</*usesPrimal*/0, /*usesShadow*/0, /*usesCustom*/0> {
100+
string condition = condition_;
101+
}
102+
93103
class Inst<string mnemonic, string dialect_, string postop_=""> : Operation</*primal*/1, /*shadow*/0> {
94104
string name = mnemonic;
95105
string dialect = dialect_;
@@ -99,13 +109,14 @@ class Inst<string mnemonic, string dialect_, string postop_=""> : Operation</*p
99109
def Op {
100110
}
101111

102-
def SelectIfActive : Operation</*primal*/0, /*shadow*/0, /*custom*/1> {
103-
104-
}
105-
106-
def SelectIfComplex : Operation</*primal*/1, /*shadow*/0, /*custom*/0> {
112+
def SelectIfActive : StaticSelect<"!gutils->isConstantValue(imVal)">;
107113

108-
}
114+
def SelectIfComplex : StaticSelect<[{
115+
auto ty = imVal.getType();
116+
ty.isa<ComplexType>() ||
117+
ty.isa<TensorType>() &&
118+
ty.cast<TensorType>().getElementType().isa<ComplexType>();
119+
}]>;
109120

110121
class ConstantFP<string val, string dialect_, string op_, string type_=""> : Operation</*primal*/0, /*shadow*/0> {
111122
string value = val;

0 commit comments

Comments
 (0)