Skip to content

Commit a93b618

Browse files
authored
Add more blas fns (#1544)
* Add more blas fns * fix * Update Project.toml
1 parent eb79121 commit a93b618

File tree

2 files changed

+14
-11
lines changed

2 files changed

+14
-11
lines changed

Project.toml

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,16 @@ Preferences = "21216c6a-2e73-6563-6e65-726566657250"
1616
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
1717
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1818

19+
[weakdeps]
20+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
21+
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
22+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
23+
24+
[extensions]
25+
EnzymeChainRulesCoreExt = "ChainRulesCore"
26+
EnzymeSpecialFunctionsExt = "SpecialFunctions"
27+
EnzymeStaticArraysExt = "StaticArrays"
28+
1929
[compat]
2030
CEnum = "0.4, 0.5"
2131
ChainRulesCore = "1"
@@ -29,17 +39,7 @@ SpecialFunctions = "1, 2"
2939
StaticArrays = "1"
3040
julia = "1.6"
3141

32-
[extensions]
33-
EnzymeChainRulesCoreExt = "ChainRulesCore"
34-
EnzymeSpecialFunctionsExt = "SpecialFunctions"
35-
EnzymeStaticArraysExt = "StaticArrays"
36-
3742
[extras]
3843
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3944
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
4045
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
41-
42-
[weakdeps]
43-
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
44-
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
45-
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

src/compiler.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5154,10 +5154,13 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget};
51545154
check_ir(job, mod)
51555155

51565156
disableFallback = String[]
5157+
5158+
ForwardModeDerivatives = ("dot","gemm","gemv","axpy","copy","scal")
5159+
ReverseModeDerivatives = ("dot","gemm","gemv","axpy","copy","scal", "trmv", "syrk", "trmm", "trsm")
51575160
# Tablegen BLAS does not support forward mode yet
51585161
if !(mode == API.DEM_ForwardMode && Enzyme.API.runtimeActivity())
51595162
for ty in ("s", "d")
5160-
for func in ("dot","gemm","gemv","axpy","copy","scal")
5163+
for func in (mode == API.DEM_ForwardMode ? ForwardModeDerivatives : ReverseModeDerivatives)
51615164
for prefix in ("", "cblas_")
51625165
for ending in ("", "_", "64_", "_64_")
51635166
push!(disableFallback, prefix*ty*func*ending)

0 commit comments

Comments
 (0)