Skip to content

Commit 160561e

Browse files
committed
Merge branch 'main' of https://github.com/EnzymeAD/Reactant.jl into probprog-trace-operand
2 parents b92a733 + e09747c commit 160561e

File tree

3 files changed

+51
-6
lines changed

3 files changed

+51
-6
lines changed

Project.toml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Reactant"
22
uuid = "3c362404-f566-11ee-1572-e11a4b42c853"
33
authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]>", "Sergio Sánchez Ramírez <[email protected]>", "Paul Berg <[email protected]>", "Avik Pal <[email protected]>", "Mosè Giordano <[email protected]>"]
4-
version = "0.2.143"
4+
version = "0.2.144"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -43,8 +43,8 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
4343
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
4444
YaoBlocks = "418bc28f-b43b-5e0b-a6e7-61bbc1a2c1df"
4545

46-
[sources]
47-
ReactantCore = {path = "lib/ReactantCore"}
46+
[sources.ReactantCore]
47+
path = "lib/ReactantCore"
4848

4949
[extensions]
5050
ReactantAbstractFFTsExt = "AbstractFFTs"
@@ -90,13 +90,13 @@ PythonCall = "0.9.25"
9090
Random = "1.10"
9191
Random123 = "1.7"
9292
ReactantCore = "0.1.15"
93-
Reactant_jll = "0.0.214"
93+
Reactant_jll = "0.0.216"
9494
ScopedValues = "1.3.0"
9595
Scratch = "1.2"
9696
Sockets = "1.10"
9797
SpecialFunctions = "2.4"
9898
Statistics = "1.10"
99-
YaoBlocks = "0.13"
99+
YaoBlocks = "0.13, 0.14"
100100
julia = "1.10"
101101

102102
[extras]

deps/ReactantExtra/WORKSPACE

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ http_archive(
1111
urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)],
1212
)
1313

14-
ENZYMEXLA_COMMIT = "2527ca4bb8fa4499cd10ffb42ce4c2cda3738e91"
14+
ENZYMEXLA_COMMIT = "d95dbb25bac5b01b4c96816234eff81c00e2513e"
1515

1616
ENZYMEXLA_SHA256 = ""
1717

src/mlir/Dialects/EnzymeXLA.jl

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,31 @@ function gpu_wrapper(
219219
)
220220
end
221221

222+
function ml_gelu(
223+
input::Value;
224+
result=nothing::Union{Nothing,IR.Type},
225+
gelu_approximation,
226+
location=Location(),
227+
)
228+
op_ty_results = IR.Type[]
229+
operands = Value[input,]
230+
owned_regions = Region[]
231+
successors = Block[]
232+
attributes = NamedAttribute[namedattribute("gelu_approximation", gelu_approximation),]
233+
!isnothing(result) && push!(op_ty_results, result)
234+
235+
return create_operation(
236+
"enzymexla.ml.gelu",
237+
location;
238+
operands,
239+
owned_regions,
240+
successors,
241+
attributes,
242+
results=(length(op_ty_results) == 0 ? nothing : op_ty_results),
243+
result_inference=(length(op_ty_results) == 0 ? true : false),
244+
)
245+
end
246+
222247
function get_stream(; result::IR.Type, location=Location())
223248
op_ty_results = IR.Type[result,]
224249
operands = Value[]
@@ -491,6 +516,26 @@ function linalg_qr(
491516
)
492517
end
493518

519+
function ml_relu(input::Value; result=nothing::Union{Nothing,IR.Type}, location=Location())
520+
op_ty_results = IR.Type[]
521+
operands = Value[input,]
522+
owned_regions = Region[]
523+
successors = Block[]
524+
attributes = NamedAttribute[]
525+
!isnothing(result) && push!(op_ty_results, result)
526+
527+
return create_operation(
528+
"enzymexla.ml.relu",
529+
location;
530+
operands,
531+
owned_regions,
532+
successors,
533+
attributes,
534+
results=(length(op_ty_results) == 0 ? nothing : op_ty_results),
535+
result_inference=(length(op_ty_results) == 0 ? true : false),
536+
)
537+
end
538+
494539
function rotate(
495540
operand::Value;
496541
result=nothing::Union{Nothing,IR.Type},

0 commit comments

Comments
 (0)