Skip to content

Commit 525a2b5

Browse files
committed
Create CUDA and GenericTensorNetworks extensions
1 parent 457960f commit 525a2b5

File tree

6 files changed

+33
-18
lines changed

6 files changed

+33
-18
lines changed

Project.toml

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,7 @@ version = "0.4.0"
55

66
[deps]
77
Artifacts = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
8-
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
98
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
10-
GenericTensorNetworks = "3521c873-ad32-4bb4-b63d-f4f178f42b49"
119
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1210
OMEinsum = "ebe7aa44-baf0-506c-a96f-8464559b3922"
1311
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
@@ -16,13 +14,21 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1614
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1715
TropicalNumbers = "b3a74e9c-7526-4576-a4eb-79c0d4c32334"
1816

17+
[weakdeps]
18+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
19+
GenericTensorNetworks = "3521c873-ad32-4bb4-b63d-f4f178f42b49"
20+
21+
[extensions]
22+
TensorInferenceCUDAExt = "CUDA"
23+
TensorInferenceGTNExt = "GenericTensorNetworks"
24+
1925
[compat]
20-
CUDA = "4"
26+
CUDA = "4, 5"
2127
DocStringExtensions = "0.8.6, 0.9"
2228
GenericTensorNetworks = "1"
2329
OMEinsum = "0.7"
2430
PrecompileTools = "1"
2531
Requires = "1"
2632
StatsBase = "0.34"
27-
TropicalNumbers = "0.5.4"
28-
julia = "1.3"
33+
TropicalNumbers = "0.5.4, 0.6"
34+
julia = "1.9"

src/cuda.jl renamed to ext/TensorInferenceCUDAExt.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
using .CUDA: CuArray
1+
module TensorInferenceCUDAExt
2+
using CUDA: CuArray
3+
import CUDA
4+
import TensorInference: match_arraytype, keep_only!, onehot_like, togpu
25

36
function onehot_like(A::CuArray, j)
47
mask = zero(A)
@@ -15,3 +18,7 @@ function keep_only!(x::CuArray{T}, j) where T
1518
CUDA.@allowscalar x[j] = hotvalue
1619
return x
1720
end
21+
22+
togpu(x::AbstractArray) = CuArray(x)
23+
24+
end
Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
using .GenericTensorNetworks: generate_tensors, GraphProblem, flavors, labels
2-
3-
# update models
4-
export update_temperature
1+
module TensorInferenceGTNExt
2+
using TensorInference, TensorInference.OMEinsum
3+
using TensorInference: TYPEDSIGNATURES, Factor
4+
import TensorInference: update_temperature
5+
using GenericTensorNetworks: generate_tensors, GraphProblem, flavors, labels
56

67
"""
78
$TYPEDSIGNATURES
@@ -64,4 +65,5 @@ It is about one or two hours of works. If you need it, please file an issue to l
6465
end
6566

6667
@info "`TensorInference` loaded `GenericTensorNetworks` extension successfully,
67-
`TensorNetworkModel` and `MMAPModel` can be used for converting a `GraphProblem` to a probabilistic model now."
68+
`TensorNetworkModel` and `MMAPModel` can be used for converting a `GraphProblem` to a probabilistic model now."
69+
end

src/TensorInference.jl

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@ export sample
3434
# MMAP
3535
export MMAPModel
3636

37+
# for GenericTensorNetworks
38+
export update_temperature
39+
function update_temperature end
40+
3741
include("Core.jl")
3842
include("RescaledArray.jl")
3943
include("utils.jl")
@@ -42,12 +46,6 @@ include("map.jl")
4246
include("mmap.jl")
4347
include("sampling.jl")
4448

45-
using Requires
46-
function __init__()
47-
@require CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" include("cuda.jl")
48-
@require GenericTensorNetworks = "3521c873-ad32-4bb4-b63d-f4f178f42b49" include("generictensornetworks.jl")
49-
end
50-
5149
# import PrecompileTools
5250
# PrecompileTools.@setup_workload begin
5351
# # Putting some things in `@setup_workload` instead of `@compile_workload` can reduce the size of the

src/mar.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ function adapt_tensors(code, tensors, evidence; usecuda, rescale)
77
map(tensors, ixs) do t, ix
88
dims = map(ixi -> ixi keys(evidence) ? Colon() : ((evidence[ixi] + 1):(evidence[ixi] + 1)), ix)
99
t2 = t[dims...]
10-
t3 = usecuda ? CuArray(t2) : t2
10+
t3 = usecuda ? togpu(t2) : t2
1111
rescale ? rescale_array(t3) : t3
1212
end
1313
end

src/utils.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,3 +305,5 @@ function get_artifact_path(artifact_name::String)
305305
artifact_hash = Pkg.Artifacts.artifact_hash(artifact_name, artifact_toml)
306306
return Pkg.Artifacts.artifact_path(artifact_hash)
307307
end
308+
309+
togpu(x) = error("You must import CUDA with `using CUDA` before using GPU!")

0 commit comments

Comments
 (0)