Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
312 changes: 312 additions & 0 deletions Manifest.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,312 @@
# This file is machine-generated - editing it directly is not advised

[[AbstractFFTs]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "485ee0867925449198280d4af84bdb46a2a404d0"
uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c"
version = "1.0.1"

[[Adapt]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "f1b523983a58802c4695851926203b36e28f09db"
uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
version = "3.3.0"

[[ArgTools]]
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"

[[Artifacts]]
uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"

[[BFloat16s]]
deps = ["LinearAlgebra", "Test"]
git-tree-sha1 = "4af69e205efc343068dc8722b8dfec1ade89254a"
uuid = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
version = "0.1.0"

[[Base64]]
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"

[[CEnum]]
git-tree-sha1 = "215a9aa4a1f23fbd05b92769fdd62559488d70e9"
uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82"
version = "0.4.1"

[[CUDA]]
deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CompilerSupportLibraries_jll", "DataStructures", "ExprTools", "GPUArrays", "GPUCompiler", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "MacroTools", "Memoize", "NNlib", "Printf", "Random", "Reexport", "Requires", "SparseArrays", "Statistics", "TimerOutputs"]
git-tree-sha1 = "6893a46f357eabd44ce0fc1f9a264120a1a3a732"
uuid = "052768ef-5323-5732-b1bb-66c8b64840ba"
version = "2.6.3"

[[CUDAKernels]]
deps = ["Adapt", "CUDA", "Cassette", "KernelAbstractions", "SpecialFunctions", "StaticArrays"]
git-tree-sha1 = "0daaee29c30f1ea9e1a8f7c9d92b00d07aba0fa0"
uuid = "72cfdca4-0801-4ab0-bf6a-d52aa10adc57"
version = "0.1.0"

[[Cassette]]
git-tree-sha1 = "742fbff99a2798f02bd37d25087efb5615b5a207"
uuid = "7057c7e9-c182-5462-911a-8362d720325c"
version = "0.3.5"

[[ChainRulesCore]]
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
git-tree-sha1 = "42e3c181483fbd2c416087a0a93838803e358358"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "0.9.38"

[[Compat]]
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
git-tree-sha1 = "ac4132ad78082518ec2037ae5770b6e796f7f956"
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
version = "3.27.0"

[[CompilerSupportLibraries_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"

[[DataStructures]]
deps = ["Compat", "InteractiveUtils", "OrderedCollections"]
git-tree-sha1 = "4437b64df1e0adccc3e5d1adbc3ac741095e4677"
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
version = "0.18.9"

[[Dates]]
deps = ["Printf"]
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"

[[DelimitedFiles]]
deps = ["Mmap"]
uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab"

[[Distributed]]
deps = ["Random", "Serialization", "Sockets"]
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"

[[Downloads]]
deps = ["ArgTools", "LibCURL", "NetworkOptions"]
uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6"

[[ExprTools]]
git-tree-sha1 = "10407a39b87f29d47ebaca8edbc75d7c302ff93e"
uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
version = "0.1.3"

[[GPUArrays]]
deps = ["AbstractFFTs", "Adapt", "LinearAlgebra", "Printf", "Random", "Serialization"]
git-tree-sha1 = "9c95b2fd5c16bc7f97371e9f92f0fef77e0f5957"
uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
version = "6.2.2"

[[GPUCompiler]]
deps = ["DataStructures", "ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "Scratch", "Serialization", "TimerOutputs", "UUIDs"]
git-tree-sha1 = "ef2839b063e158672583b9c09d2cf4876a8d3d55"
uuid = "61eb1bfa-7361-4325-ad38-22787b887f55"
version = "0.10.0"

[[InteractiveUtils]]
deps = ["Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"

[[JLLWrappers]]
deps = ["Preferences"]
git-tree-sha1 = "642a199af8b68253517b80bd3bfd17eb4e84df6e"
uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210"
version = "1.3.0"

[[KernelAbstractions]]
deps = ["Adapt", "Cassette", "InteractiveUtils", "MacroTools", "SpecialFunctions", "StaticArrays", "UUIDs"]
git-tree-sha1 = "0158a3efe86a7c0a3162ded188172a0ad302353b"
uuid = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
version = "0.6.0"

[[LLVM]]
deps = ["CEnum", "Libdl", "Printf", "Unicode"]
git-tree-sha1 = "b616937c31337576360cb9fb872ec7633af7b194"
uuid = "929cbde3-209d-540e-8aea-75f648917ca0"
version = "3.6.0"

[[LazyArtifacts]]
deps = ["Artifacts", "Pkg"]
uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3"

[[LibCURL]]
deps = ["LibCURL_jll", "MozillaCACerts_jll"]
uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21"

[[LibCURL_jll]]
deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"]
uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0"

[[LibGit2]]
deps = ["Base64", "NetworkOptions", "Printf", "SHA"]
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"

[[LibSSH2_jll]]
deps = ["Artifacts", "Libdl", "MbedTLS_jll"]
uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8"

[[Libdl]]
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"

[[LinearAlgebra]]
deps = ["Libdl"]
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[[Logging]]
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"

[[MacroTools]]
deps = ["Markdown", "Random"]
git-tree-sha1 = "6a8a2a625ab0dea913aba95c11370589e0239ff0"
uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
version = "0.5.6"

[[Markdown]]
deps = ["Base64"]
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"

[[MbedTLS_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1"

[[Memoize]]
deps = ["MacroTools"]
git-tree-sha1 = "2b1dfcba103de714d31c033b5dacc2e4a12c7caa"
uuid = "c03570c3-d221-55d1-a50c-7939bbd78826"
version = "0.4.4"

[[Mmap]]
uuid = "a63ad114-7e13-5084-954f-fe012c677804"

[[MozillaCACerts_jll]]
uuid = "14a3606d-f60d-562e-9121-12d972cd8159"

[[NNlib]]
deps = ["Adapt", "ChainRulesCore", "Compat", "LinearAlgebra", "Pkg", "Requires", "Statistics"]
git-tree-sha1 = "723c0d5252bf95808f934b2384519dd325869f40"
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
version = "0.7.18"

[[NetworkOptions]]
uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"

[[OpenSpecFun_jll]]
deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"]
git-tree-sha1 = "9db77584158d0ab52307f8c04f8e7c08ca76b5b3"
uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e"
version = "0.5.3+4"

[[OrderedCollections]]
git-tree-sha1 = "4fa2ba51070ec13fcc7517db714445b4ab986bdf"
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
version = "1.4.0"

[[Pkg]]
deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"]
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"

[[Preferences]]
deps = ["TOML"]
git-tree-sha1 = "ea79e4c9077208cd3bc5d29631a26bc0cff78902"
uuid = "21216c6a-2e73-6563-6e65-726566657250"
version = "1.2.1"

[[Printf]]
deps = ["Unicode"]
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"

[[REPL]]
deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"]
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"

[[Random]]
deps = ["Serialization"]
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[[Reexport]]
git-tree-sha1 = "57d8440b0c7d98fc4f889e478e80f268d534c9d5"
uuid = "189a3867-3050-52da-a836-e630ba90ab69"
version = "1.0.0"

[[Requires]]
deps = ["UUIDs"]
git-tree-sha1 = "4036a3bd08ac7e968e27c203d45f5fff15020621"
uuid = "ae029012-a4dd-5104-9daa-d747884805df"
version = "1.1.3"

[[SHA]]
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"

[[Scratch]]
deps = ["Dates"]
git-tree-sha1 = "ad4b278adb62d185bbcb6864dc24959ab0627bf6"
uuid = "6c6a2e73-6563-6170-7368-637461726353"
version = "1.0.3"

[[Serialization]]
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"

[[SharedArrays]]
deps = ["Distributed", "Mmap", "Random", "Serialization"]
uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383"

[[Sockets]]
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"

[[SparseArrays]]
deps = ["LinearAlgebra", "Random"]
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[[SpecialFunctions]]
deps = ["ChainRulesCore", "OpenSpecFun_jll"]
git-tree-sha1 = "5919936c0e92cff40e57d0ddf0ceb667d42e5902"
uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
version = "1.3.0"

[[StaticArrays]]
deps = ["LinearAlgebra", "Random", "Statistics"]
git-tree-sha1 = "e8cd1b100d37f5b4cfd2c83f45becf61c762eaf7"
uuid = "90137ffa-7385-5640-81b9-e52037218182"
version = "1.1.1"

[[Statistics]]
deps = ["LinearAlgebra", "SparseArrays"]
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[[TOML]]
deps = ["Dates"]
uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76"

[[Tar]]
deps = ["ArgTools", "SHA"]
uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e"

[[Test]]
deps = ["InteractiveUtils", "Logging", "Random", "Serialization"]
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[[TimerOutputs]]
deps = ["Printf"]
git-tree-sha1 = "32cdbe6cd2d214c25a0b88f985c9e0092877c236"
uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
version = "0.5.8"

[[UUIDs]]
deps = ["Random", "SHA"]
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"

[[Unicode]]
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"

[[Zlib_jll]]
deps = ["Libdl"]
uuid = "83775a58-1f1d-513f-b197-d71354ab007a"

[[nghttp2_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d"

[[p7zip_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0"
8 changes: 8 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
name = "EnzymaticOcean"
uuid = "1e308aaf-769c-447f-af17-2df7aeedce0a"
version = "0.1.0"

[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
CUDAKernels = "72cfdca4-0801-4ab0-bf6a-d52aa10adc57"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
45 changes: 45 additions & 0 deletions examples/simple_ode.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
using KernelAbstractions
using CUDAKernels
using CUDA

@kernel function time_step!(x, k, Δt, N)
I = @index(Global)
for _ in 1:N
@inbounds x[I] = x[I] + Δt * k * x[I]
end
end

"""
simulate_exponential_growth(; x₀, k, Δt, N, device)

Simulates exponential growth according to the ODE

dx/dt = k * x, x(t = 0) = x₀

using forward Euler time stepping with time step `Δt` for `N` iterations.
"""
function simulate_exponential_growth(; x₀, k, Δt, N, device)

if device isa CPU
x = [x₀]
elseif device isa CUDADevice
x = CuArray([x₀])
end

time_step_kernel! = time_step!(device, 1)
event = time_step_kernel!(x, k, Δt, N, ndrange=1)
wait(event)

return CUDA.@allowscalar x[1]
end

# 1 kernel launch
@show simulate_exponential_growth(x₀=1.0, k=2.5, Δt=0.01, N=1, device=CPU())
@show simulate_exponential_growth(x₀=1.0, k=2.5, Δt=0.01, N=1, device=CUDADevice())

# iterative kernel launches
@show simulate_exponential_growth(x₀=1.0, k=2.5, Δt=0.01, N=100, device=CPU())
@show simulate_exponential_growth(x₀=1.0, k=2.5, Δt=0.01, N=100, device=CUDADevice())
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These three lines are now also a single kernel launch, right?


# Optimization problem:
# If x₀=1, Δt = 0.01, N = 100, and x(t = 1) = π then what is k?
Copy link
Collaborator

@glwagner glwagner Apr 21, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we add something like

using Enzyme

simulate_exponential_growth(k) = simulate_exponential_growth(; x₀=1.0, k=k, Δt=0.01, N=100, device=CPU())

@show autodiff(simulate_exponential_growth, Active(1.0))

to see what happens?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, I realized I was a little confused. We need to define a loss function and take the gradient of that. Perhaps we can define a k_truth, and compare the result of the simulation at a given k_test? I'll make a comment on the kernel.

5 changes: 5 additions & 0 deletions src/EnzymaticOcean.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
module EnzymaticOcean

greet() = print("Hello World!")

end # module