diff --git a/README.md b/README.md index ef7357a1..6a8e4f75 100644 --- a/README.md +++ b/README.md @@ -77,7 +77,7 @@ Flux v0.13 is the latest right now, marked with ☀️; models upgraded to use * [ConvMixer "Patches are all you need?"](vision/convmixer_cifar10/) ☀️ v0.13 **Text** -* [CharRNN](text/char-rnn) ☀️ v0.13.9 +* [CharRNN](text/char-rnn) ☀️ v0.13 + * [Character-level language detection](text/lang-detection) ⛅️ v0.11 * [Seq2Seq phoneme detection on CMUDict](text/phonemes) ⛅️ v0.11 * [Recursive net on IMDB sentiment treebank](text/treebank) ⛅️ v0.11 diff --git a/text/char-rnn/Manifest.toml b/text/char-rnn/Manifest.toml index 6d60fd26..ee67b2c5 100644 --- a/text/char-rnn/Manifest.toml +++ b/text/char-rnn/Manifest.toml @@ -1,404 +1,673 @@ # This file is machine-generated - editing it directly is not advised -[[AbstractFFTs]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "051c95d6836228d120f5f4b984dd5aba1624f716" +julia_version = "1.8.2" +manifest_format = "2.0" +project_hash = "99d33959b3ed7c25da1bc930016e7f91224c9e26" + +[[deps.AbstractFFTs]] +deps = ["ChainRulesCore", "LinearAlgebra"] +git-tree-sha1 = "69f7020bd72f069c219b5e8c236c1fa90d2cb409" uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" -version = "0.5.0" +version = "1.2.1" -[[AbstractTrees]] -deps = ["Markdown"] -git-tree-sha1 = "33e450545eaf7699da1a6e755f9ea65f14077a45" -uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" -version = "0.3.3" +[[deps.Accessors]] +deps = ["Compat", "CompositionsBase", "ConstructionBase", "Dates", "InverseFunctions", "LinearAlgebra", "MacroTools", "Requires", "Test"] +git-tree-sha1 = "3fa8cc751763c91a5ea33331e523221009cb1e6f" +uuid = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" +version = "0.1.23" -[[Adapt]] +[[deps.Adapt]] deps = ["LinearAlgebra"] -git-tree-sha1 = "345a14764e43fe927d6f5c250fe4c8e4664e6ee8" +git-tree-sha1 = "195c5505521008abea5aee4f96930717958eac6f" uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -version = "2.4.0" +version = "3.4.0" -[[Artifacts]] -deps = ["Pkg"] -git-tree-sha1 = "c30985d8821e0cd73870b17b0ed0ce6dc44cb744" +[[deps.ArgCheck]] +git-tree-sha1 = "a3a402a35a2f7e0b87828ccabbd5ebfbebe356b4" +uuid = "dce04be8-c92d-5529-be00-80e4d2c0e197" +version = "2.3.0" + +[[deps.ArgTools]] +uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" +version = "1.1.1" + +[[deps.Artifacts]] uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" -version = "1.3.0" -[[BFloat16s]] -deps = ["LinearAlgebra", "Test"] -git-tree-sha1 = "4af69e205efc343068dc8722b8dfec1ade89254a" +[[deps.BFloat16s]] +deps = ["LinearAlgebra", "Printf", "Random", "Test"] +git-tree-sha1 = "a598ecb0d717092b5539dbbe890c98bac842b072" uuid = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" -version = "0.1.0" +version = "0.2.0" + +[[deps.BangBang]] +deps = ["Compat", "ConstructionBase", "Future", "InitialValues", "LinearAlgebra", "Requires", "Setfield", "Tables", "ZygoteRules"] +git-tree-sha1 = "7fe6d92c4f281cf4ca6f2fba0ce7b299742da7ca" +uuid = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" +version = "0.3.37" -[[Base64]] +[[deps.Base64]] uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" -[[CEnum]] -git-tree-sha1 = "215a9aa4a1f23fbd05b92769fdd62559488d70e9" +[[deps.Baselet]] +git-tree-sha1 = "aebf55e6d7795e02ca500a689d326ac979aaf89e" +uuid = "9718e550-a3fa-408a-8086-8db961cd8217" +version = "0.1.1" + +[[deps.CEnum]] +git-tree-sha1 = "eb4cb44a499229b3b8426dcfb5dd85333951ff90" uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82" -version = "0.4.1" +version = "0.4.2" -[[CUDA]] -deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CompilerSupportLibraries_jll", "DataStructures", "ExprTools", "GPUArrays", "GPUCompiler", "LLVM", "Libdl", "LinearAlgebra", "Logging", "MacroTools", "NNlib", "Pkg", "Printf", "Random", "Reexport", "Requires", "SparseArrays", "Statistics", "TimerOutputs"] -git-tree-sha1 = "39f6f584bec264ace76f924d1c8637c85617697e" +[[deps.CUDA]] +deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CompilerSupportLibraries_jll", "ExprTools", "GPUArrays", "GPUCompiler", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions", "TimerOutputs"] +git-tree-sha1 = "a56dff7bc49b5d5ac43d2c10eb2aef94becd5251" uuid = "052768ef-5323-5732-b1bb-66c8b64840ba" -version = "2.4.0" +version = "3.12.1" -[[ChainRules]] -deps = ["ChainRulesCore", "Compat", "LinearAlgebra", "Random", "Reexport", "Requires", "Statistics"] -git-tree-sha1 = "0af5c12e5528fc2df87a5f084195f10bfbf03a28" +[[deps.ChainRules]] +deps = ["Adapt", "ChainRulesCore", "Compat", "Distributed", "GPUArraysCore", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "Statistics", "StructArrays"] +git-tree-sha1 = "c46adabdd0348f0ee8de91142cfc4a72a613ac0a" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "0.7.48" +version = "1.46.1" -[[ChainRulesCore]] +[[deps.ChainRulesCore]] deps = ["Compat", "LinearAlgebra", "SparseArrays"] -git-tree-sha1 = "89a0b14325d0f02f9caed7c8ba91181a5d254874" +git-tree-sha1 = "e7ff6cadf743c098e08fca25c91103ee4303c9bb" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "0.9.26" +version = "1.15.6" -[[CodecZlib]] -deps = ["TranscodingStreams", "Zlib_jll"] -git-tree-sha1 = "ded953804d019afa9a3f98981d99b33e3db7b6da" -uuid = "944b1d66-785c-5afd-91f1-9de20f533193" -version = "0.7.0" - -[[ColorTypes]] -deps = ["FixedPointNumbers", "Random"] -git-tree-sha1 = "4bffea7ed1a9f0f3d1a131bbcd4b925548d75288" -uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f" -version = "0.10.9" +[[deps.ChangesOfVariables]] +deps = ["ChainRulesCore", "LinearAlgebra", "Test"] +git-tree-sha1 = "38f7a08f19d8810338d4f5085211c7dfa5d5bdd8" +uuid = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" +version = "0.1.4" -[[Colors]] -deps = ["ColorTypes", "FixedPointNumbers", "InteractiveUtils", "Reexport"] -git-tree-sha1 = "ac5f2213e56ed8a34a3dd2f681f4df1166b34929" -uuid = "5ae59095-9a9b-59fe-a467-6f913c188581" -version = "0.12.6" - -[[CommonSubexpressions]] +[[deps.CommonSubexpressions]] deps = ["MacroTools", "Test"] git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7" uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" version = "0.3.0" -[[Compat]] -deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"] -git-tree-sha1 = "919c7f3151e79ff196add81d7f4e45d91bbf420b" +[[deps.Compat]] +deps = ["Dates", "LinearAlgebra", "UUIDs"] +git-tree-sha1 = "00a2cccc7f098ff3b66806862d275ca3db9e6e5a" uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" -version = "3.25.0" +version = "4.5.0" -[[CompilerSupportLibraries_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "8e695f735fca77e9708e795eda62afdb869cbb70" +[[deps.CompilerSupportLibraries_jll]] +deps = ["Artifacts", "Libdl"] uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" -version = "0.3.4+0" +version = "0.5.2+0" + +[[deps.CompositionsBase]] +git-tree-sha1 = "455419f7e328a1a2493cabc6428d79e951349769" +uuid = "a33af91c-f02d-484b-be07-31d278c5ca2b" +version = "0.1.1" + +[[deps.ConstructionBase]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "fb21ddd70a051d882a1686a5a550990bbe371a95" +uuid = "187b0558-2788-49d3-abe0-74a17ed4e7c9" +version = "1.4.1" + +[[deps.ContextVariablesX]] +deps = ["Compat", "Logging", "UUIDs"] +git-tree-sha1 = "25cc3803f1030ab855e383129dcd3dc294e322cc" +uuid = "6add18c4-b38d-439d-96f6-d6bc489c04c5" +version = "0.1.3" -[[DataAPI]] -git-tree-sha1 = "ad84f52c0b8f05aa20839484dbaf01690b41ff84" +[[deps.DataAPI]] +git-tree-sha1 = "e8119c1a33d267e16108be441a287a6981ba1630" uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" -version = "1.4.0" +version = "1.14.0" -[[DataStructures]] +[[deps.DataStructures]] deps = ["Compat", "InteractiveUtils", "OrderedCollections"] -git-tree-sha1 = "fb0aa371da91c1ff9dc7fbed6122d3e411420b9c" +git-tree-sha1 = "d1fff3a548102f48987a52a2e0d114fa97d730f0" uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" -version = "0.18.8" +version = "0.18.13" + +[[deps.DataValueInterfaces]] +git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6" +uuid = "e2d170a0-9d28-54be-80f0-106bbe20a464" +version = "1.0.0" -[[Dates]] +[[deps.Dates]] deps = ["Printf"] uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" -[[DelimitedFiles]] +[[deps.DefineSingletons]] +git-tree-sha1 = "0fba8b706d0178b4dc7fd44a96a92382c9065c2c" +uuid = "244e2a9f-e319-4986-a169-4d1fe445cd52" +version = "0.1.2" + +[[deps.DelimitedFiles]] deps = ["Mmap"] uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" -[[DiffResults]] -deps = ["StaticArrays"] -git-tree-sha1 = "c18e98cba888c6c25d1c3b048e4b3380ca956805" +[[deps.DiffResults]] +deps = ["StaticArraysCore"] +git-tree-sha1 = "782dd5f4561f5d267313f23853baaaa4c52ea621" uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" -version = "1.0.3" +version = "1.1.0" -[[DiffRules]] -deps = ["NaNMath", "Random", "SpecialFunctions"] -git-tree-sha1 = "214c3fcac57755cfda163d91c58893a8723f93e9" +[[deps.DiffRules]] +deps = ["IrrationalConstants", "LogExpFunctions", "NaNMath", "Random", "SpecialFunctions"] +git-tree-sha1 = "c5b6685d53f933c11404a3ae9822afe30d522494" uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" -version = "1.0.2" +version = "1.12.2" -[[Distributed]] +[[deps.Distributed]] deps = ["Random", "Serialization", "Sockets"] uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" -[[ExprTools]] -git-tree-sha1 = "10407a39b87f29d47ebaca8edbc75d7c302ff93e" +[[deps.DocStringExtensions]] +deps = ["LibGit2"] +git-tree-sha1 = "2fb1e02f2b635d0845df5d7c167fec4dd739b00d" +uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +version = "0.9.3" + +[[deps.Downloads]] +deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] +uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" +version = "1.6.0" + +[[deps.ExprTools]] +git-tree-sha1 = "56559bbef6ca5ea0c0818fa5c90320398a6fbf8d" uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04" -version = "0.1.3" +version = "0.1.8" -[[FillArrays]] -deps = ["LinearAlgebra", "Random", "SparseArrays"] -git-tree-sha1 = "8bd8e47ff5d34b20f0aa9641988eb660590008bc" -uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" -version = "0.11.0" +[[deps.FLoops]] +deps = ["BangBang", "Compat", "FLoopsBase", "InitialValues", "JuliaVariables", "MLStyle", "Serialization", "Setfield", "Transducers"] +git-tree-sha1 = "ffb97765602e3cbe59a0589d237bf07f245a8576" +uuid = "cc61a311-1640-44b5-9fba-1b764f453329" +version = "0.2.1" + +[[deps.FLoopsBase]] +deps = ["ContextVariablesX"] +git-tree-sha1 = "656f7a6859be8673bf1f35da5670246b923964f7" +uuid = "b9860ae5-e623-471e-878b-f6a53c775ea6" +version = "0.1.1" -[[FixedPointNumbers]] -deps = ["Statistics"] -git-tree-sha1 = "335bfdceacc84c5cdf16aadc768aa5ddfc5383cc" -uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93" -version = "0.8.4" +[[deps.FileWatching]] +uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" -[[Flux]] -deps = ["AbstractTrees", "Adapt", "CUDA", "CodecZlib", "Colors", "DelimitedFiles", "Functors", "Juno", "LinearAlgebra", "MacroTools", "NNlib", "Pkg", "Printf", "Random", "Reexport", "SHA", "Statistics", "StatsBase", "Test", "ZipFile", "Zygote"] -git-tree-sha1 = "f688d61b40b345aa9f0a4a41d3ca7750ad9cb1f6" +[[deps.FillArrays]] +deps = ["LinearAlgebra", "Random", "SparseArrays", "Statistics"] +git-tree-sha1 = "9a0472ec2f5409db243160a8b030f94c380167a3" +uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" +version = "0.13.6" + +[[deps.Flux]] +deps = ["Adapt", "CUDA", "ChainRulesCore", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "NNlibCUDA", "OneHotArrays", "Optimisers", "ProgressLogging", "Random", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "StatsBase", "Zygote"] +git-tree-sha1 = "518b553ec3776dde058aebd2750c109d04ee5593" uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" -version = "0.11.4" +version = "0.13.11" + +[[deps.FoldsThreads]] +deps = ["Accessors", "FunctionWrappers", "InitialValues", "SplittablesBase", "Transducers"] +git-tree-sha1 = "eb8e1989b9028f7e0985b4268dabe94682249025" +uuid = "9c68100b-dfe1-47cf-94c8-95104e173443" +version = "0.1.1" -[[ForwardDiff]] -deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "NaNMath", "Random", "SpecialFunctions", "StaticArrays"] -git-tree-sha1 = "c26b56e9b9f0687f7ca887f6b6ded03d269e0e35" +[[deps.ForwardDiff]] +deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions", "StaticArrays"] +git-tree-sha1 = "a69dd6db8a809f78846ff259298678f0d6212180" uuid = "f6369f11-7733-5829-9624-2563aa707210" -version = "0.10.15" +version = "0.10.34" -[[Functors]] -deps = ["MacroTools"] -git-tree-sha1 = "f40adc6422f548176bb4351ebd29e4abf773040a" +[[deps.FunctionWrappers]] +git-tree-sha1 = "d62485945ce5ae9c0c48f124a84998d755bae00e" +uuid = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e" +version = "1.1.3" + +[[deps.Functors]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "993c2b4a9a54496b6d8e265db1244db418f37e01" uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196" -version = "0.1.0" +version = "0.4.1" + +[[deps.Future]] +deps = ["Random"] +uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" -[[GPUArrays]] -deps = ["AbstractFFTs", "Adapt", "LinearAlgebra", "Printf", "Random", "Serialization"] -git-tree-sha1 = "f99a25fe0313121f2f9627002734c7d63b4dd3bd" +[[deps.GPUArrays]] +deps = ["Adapt", "GPUArraysCore", "LLVM", "LinearAlgebra", "Printf", "Random", "Reexport", "Serialization", "Statistics"] +git-tree-sha1 = "45d7deaf05cbb44116ba785d147c518ab46352d7" uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" -version = "6.2.0" +version = "8.5.0" -[[GPUCompiler]] -deps = ["DataStructures", "InteractiveUtils", "LLVM", "Libdl", "Scratch", "Serialization", "TimerOutputs", "UUIDs"] -git-tree-sha1 = "c853c810b52a80f9aad79ab109207889e57f41ef" +[[deps.GPUArraysCore]] +deps = ["Adapt"] +git-tree-sha1 = "6872f5ec8fd1a38880f027a26739d42dcda6691f" +uuid = "46192b85-c4d5-4398-a991-12ede77f4527" +version = "0.1.2" + +[[deps.GPUCompiler]] +deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "TimerOutputs", "UUIDs"] +git-tree-sha1 = "48832a7cacbe56e591a7bef690c78b9d00bcc692" uuid = "61eb1bfa-7361-4325-ad38-22787b887f55" -version = "0.8.3" +version = "0.17.1" -[[IRTools]] +[[deps.IRTools]] deps = ["InteractiveUtils", "MacroTools", "Test"] -git-tree-sha1 = "c67e7515a11f726f44083e74f218d134396d6510" +git-tree-sha1 = "2e99184fca5eb6f075944b04c22edec29beb4778" uuid = "7869d1d1-7146-5819-86e3-90919afe41df" -version = "0.4.2" +version = "0.4.7" + +[[deps.InitialValues]] +git-tree-sha1 = "4da0f88e9a39111c2fa3add390ab15f3a44f3ca3" +uuid = "22cec73e-a1b8-11e9-2c92-598750a2cf9c" +version = "0.3.1" -[[InteractiveUtils]] +[[deps.InteractiveUtils]] deps = ["Markdown"] uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" -[[JLLWrappers]] -git-tree-sha1 = "a431f5f2ca3f4feef3bd7a5e94b8b8d4f2f647a0" +[[deps.InverseFunctions]] +deps = ["Test"] +git-tree-sha1 = "49510dfcb407e572524ba94aeae2fced1f3feb0f" +uuid = "3587e190-3f89-42d0-90ee-14403ec27112" +version = "0.1.8" + +[[deps.IrrationalConstants]] +git-tree-sha1 = "7fd44fd4ff43fc60815f8e764c0f352b83c49151" +uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" +version = "0.1.1" + +[[deps.IteratorInterfaceExtensions]] +git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856" +uuid = "82899510-4779-5014-852e-03e436cf321d" +version = "1.0.0" + +[[deps.JLLWrappers]] +deps = ["Preferences"] +git-tree-sha1 = "abc9885a7ca2052a736a600f7fa66209f96506e1" uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" -version = "1.2.0" +version = "1.4.1" -[[Juno]] -deps = ["Base64", "Logging", "Media", "Profile"] -git-tree-sha1 = "07cb43290a840908a771552911a6274bc6c072c7" -uuid = "e5e0dc1b-0480-54bc-9374-aad01c23163d" -version = "0.8.4" +[[deps.JuliaVariables]] +deps = ["MLStyle", "NameResolution"] +git-tree-sha1 = "49fb3cb53362ddadb4415e9b73926d6b40709e70" +uuid = "b14d175d-62b4-44ba-8fb7-3064adc8c3ec" +version = "0.2.4" -[[LLVM]] -deps = ["CEnum", "Libdl", "Printf", "Unicode"] -git-tree-sha1 = "d0d99629d6ae4a3e211ae83d8870907bd842c811" +[[deps.LLVM]] +deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Printf", "Unicode"] +git-tree-sha1 = "088dd02b2797f0233d92583562ab669de8517fd1" uuid = "929cbde3-209d-540e-8aea-75f648917ca0" -version = "3.5.2" - -[[LibGit2]] -deps = ["Printf"] +version = "4.14.1" + +[[deps.LLVMExtra_jll]] +deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg", "TOML"] +git-tree-sha1 = "771bfe376249626d3ca12bcd58ba243d3f961576" +uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab" +version = "0.0.16+0" + +[[deps.LazyArtifacts]] +deps = ["Artifacts", "Pkg"] +uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3" + +[[deps.LibCURL]] +deps = ["LibCURL_jll", "MozillaCACerts_jll"] +uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" +version = "0.6.3" + +[[deps.LibCURL_jll]] +deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] +uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" +version = "7.84.0+0" + +[[deps.LibGit2]] +deps = ["Base64", "NetworkOptions", "Printf", "SHA"] uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" -[[Libdl]] +[[deps.LibSSH2_jll]] +deps = ["Artifacts", "Libdl", "MbedTLS_jll"] +uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" +version = "1.10.2+0" + +[[deps.Libdl]] uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" -[[LinearAlgebra]] -deps = ["Libdl"] +[[deps.LinearAlgebra]] +deps = ["Libdl", "libblastrampoline_jll"] uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -[[Logging]] +[[deps.LogExpFunctions]] +deps = ["ChainRulesCore", "ChangesOfVariables", "DocStringExtensions", "InverseFunctions", "IrrationalConstants", "LinearAlgebra"] +git-tree-sha1 = "946607f84feb96220f480e0422d3484c49c00239" +uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" +version = "0.3.19" + +[[deps.Logging]] uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" -[[MacroTools]] +[[deps.MLStyle]] +git-tree-sha1 = "060ef7956fef2dc06b0e63b294f7dbfbcbdc7ea2" +uuid = "d8e11817-5142-5d16-987a-aa16d5891078" +version = "0.4.16" + +[[deps.MLUtils]] +deps = ["ChainRulesCore", "Compat", "DataAPI", "DelimitedFiles", "FLoops", "FoldsThreads", "NNlib", "Random", "ShowCases", "SimpleTraits", "Statistics", "StatsBase", "Tables", "Transducers"] +git-tree-sha1 = "266c67f773feb756474c2c4a7424ea5363290300" +uuid = "f1d291b0-491e-4a28-83b9-f70985020b54" +version = "0.4.0" + +[[deps.MacroTools]] deps = ["Markdown", "Random"] -git-tree-sha1 = "6a8a2a625ab0dea913aba95c11370589e0239ff0" +git-tree-sha1 = "42324d08725e200c23d4dfb549e0d5d89dede2d2" uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" -version = "0.5.6" +version = "0.5.10" -[[Markdown]] +[[deps.Markdown]] deps = ["Base64"] uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" -[[Media]] -deps = ["MacroTools", "Test"] -git-tree-sha1 = "75a54abd10709c01f1b86b84ec225d26e840ed58" -uuid = "e89f7d12-3494-54d1-8411-f7d8b9ae1f27" -version = "0.5.0" +[[deps.MbedTLS_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" +version = "2.28.0+0" -[[Missings]] +[[deps.MicroCollections]] +deps = ["BangBang", "InitialValues", "Setfield"] +git-tree-sha1 = "4d5917a26ca33c66c8e5ca3247bd163624d35493" +uuid = "128add7d-3638-4c79-886c-908ea0c25c34" +version = "0.1.3" + +[[deps.Missings]] deps = ["DataAPI"] -git-tree-sha1 = "ed61674a0864832495ffe0a7e889c0da76b0f4c8" +git-tree-sha1 = "f66bdc5de519e8f8ae43bdc598782d35a25b1272" uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" -version = "0.4.4" +version = "1.1.0" -[[Mmap]] +[[deps.Mmap]] uuid = "a63ad114-7e13-5084-954f-fe012c677804" -[[NNlib]] -deps = ["ChainRulesCore", "LinearAlgebra", "Pkg", "Requires", "Statistics"] -git-tree-sha1 = "13fd29731c7f609cb82a3a544c5538584d22c153" +[[deps.MozillaCACerts_jll]] +uuid = "14a3606d-f60d-562e-9121-12d972cd8159" +version = "2022.2.1" + +[[deps.NNlib]] +deps = ["Adapt", "ChainRulesCore", "LinearAlgebra", "Pkg", "Random", "Requires", "Statistics"] +git-tree-sha1 = "03541c7a6dc3010cb2e33a01295f3cd35b5fd41e" uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -version = "0.7.11" +version = "0.8.15" + +[[deps.NNlibCUDA]] +deps = ["Adapt", "CUDA", "LinearAlgebra", "NNlib", "Random", "Statistics"] +git-tree-sha1 = "b05a082b08a3af0e5c576883bc6dfb6513e7e478" +uuid = "a00861dc-f156-4864-bf3c-e6376f28a68d" +version = "0.2.6" -[[NaNMath]] -git-tree-sha1 = "bfe47e760d60b82b66b61d2d44128b62e3a369fb" +[[deps.NaNMath]] +deps = ["OpenLibm_jll"] +git-tree-sha1 = "a7c3d1da1189a1c2fe843a3bfa04d18d20eb3211" uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" -version = "0.3.5" +version = "1.0.1" + +[[deps.NameResolution]] +deps = ["PrettyPrint"] +git-tree-sha1 = "1a0fa0e9613f46c9b8c11eee38ebb4f590013c5e" +uuid = "71a1bf82-56d0-4bbc-8a3c-48b961074391" +version = "0.1.5" -[[OpenSpecFun_jll]] +[[deps.NetworkOptions]] +uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" +version = "1.2.0" + +[[deps.OneHotArrays]] +deps = ["Adapt", "ChainRulesCore", "Compat", "GPUArraysCore", "LinearAlgebra", "NNlib"] +git-tree-sha1 = "f511fca956ed9e70b80cd3417bb8c2dde4b68644" +uuid = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f" +version = "0.2.3" + +[[deps.OpenBLAS_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] +uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" +version = "0.3.20+0" + +[[deps.OpenLibm_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "05823500-19ac-5b8b-9628-191a04bc5112" +version = "0.8.1+0" + +[[deps.OpenSpecFun_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "9db77584158d0ab52307f8c04f8e7c08ca76b5b3" +git-tree-sha1 = "13652491f6856acfd2db29360e1bbcd4565d04f1" uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" -version = "0.5.3+4" +version = "0.5.5+0" -[[OrderedCollections]] -git-tree-sha1 = "cf59cfed2e2c12e8a2ff0a4f1e9b2cd8650da6db" -uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" -version = "1.3.2" +[[deps.Optimisers]] +deps = ["ChainRulesCore", "Functors", "LinearAlgebra", "Random", "Statistics"] +git-tree-sha1 = "e657acef119cc0de2a8c0762666d3b64727b053b" +uuid = "3bd65402-5787-11e9-1adc-39752487f4e2" +version = "0.2.14" -[[Parameters]] -deps = ["OrderedCollections", "UnPack"] -git-tree-sha1 = "38b2e970043613c187bd56a995fe2e551821eb4a" -uuid = "d96e819e-fc66-5662-9728-84c9c7592b0a" -version = "0.12.1" +[[deps.OrderedCollections]] +git-tree-sha1 = "85f8e6578bf1f9ee0d11e7bb1b1456435479d47c" +uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" +version = "1.4.1" -[[Pkg]] -deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"] +[[deps.Pkg]] +deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +version = "1.8.0" -[[Printf]] +[[deps.Preferences]] +deps = ["TOML"] +git-tree-sha1 = "47e5f437cc0e7ef2ce8406ce1e7e24d44915f88d" +uuid = "21216c6a-2e73-6563-6e65-726566657250" +version = "1.3.0" + +[[deps.PrettyPrint]] +git-tree-sha1 = "632eb4abab3449ab30c5e1afaa874f0b98b586e4" +uuid = "8162dcfd-2161-5ef2-ae6c-7681170c5f98" +version = "0.2.0" + +[[deps.Printf]] deps = ["Unicode"] uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" -[[Profile]] -deps = ["Printf"] -uuid = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79" +[[deps.ProgressLogging]] +deps = ["Logging", "SHA", "UUIDs"] +git-tree-sha1 = "80d919dee55b9c50e8d9e2da5eeafff3fe58b539" +uuid = "33c8b6b6-d38a-422a-b730-caa89a2f386c" +version = "0.1.4" -[[REPL]] -deps = ["InteractiveUtils", "Markdown", "Sockets"] +[[deps.REPL]] +deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" -[[Random]] -deps = ["Serialization"] +[[deps.Random]] +deps = ["SHA", "Serialization"] uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -[[Reexport]] -deps = ["Pkg"] -git-tree-sha1 = "7b1d07f411bc8ddb7977ec7f377b97b158514fe0" +[[deps.Random123]] +deps = ["Random", "RandomNumbers"] +git-tree-sha1 = "7a1a306b72cfa60634f03a911405f4e64d1b718b" +uuid = "74087812-796a-5b5d-8853-05524746bad3" +version = "1.6.0" + +[[deps.RandomNumbers]] +deps = ["Random", "Requires"] +git-tree-sha1 = "043da614cc7e95c703498a491e2c21f58a2b8111" +uuid = "e6cf234a-135c-5ec9-84dd-332b85af5143" +version = "1.5.3" + +[[deps.RealDot]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "9f0a1b71baaf7650f4fa8a1d168c7fb6ee41f0c9" +uuid = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9" +version = "0.1.0" + +[[deps.Reexport]] +git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b" uuid = "189a3867-3050-52da-a836-e630ba90ab69" -version = "0.2.0" +version = "1.2.2" -[[Requires]] +[[deps.Requires]] deps = ["UUIDs"] -git-tree-sha1 = "cfbac6c1ed70c002ec6361e7fd334f02820d6419" +git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7" uuid = "ae029012-a4dd-5104-9daa-d747884805df" -version = "1.1.2" +version = "1.3.0" -[[SHA]] +[[deps.SHA]] uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" +version = "0.7.0" -[[Scratch]] -deps = ["Dates"] -git-tree-sha1 = "ad4b278adb62d185bbcb6864dc24959ab0627bf6" -uuid = "6c6a2e73-6563-6170-7368-637461726353" -version = "1.0.3" - -[[Serialization]] +[[deps.Serialization]] uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" -[[SharedArrays]] -deps = ["Distributed", "Mmap", "Random", "Serialization"] -uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383" +[[deps.Setfield]] +deps = ["ConstructionBase", "Future", "MacroTools", "StaticArraysCore"] +git-tree-sha1 = "e2cc6d8c88613c05e1defb55170bf5ff211fbeac" +uuid = "efcf1570-3423-57d1-acb7-fd33fddbac46" +version = "1.1.1" + +[[deps.ShowCases]] +git-tree-sha1 = "7f534ad62ab2bd48591bdeac81994ea8c445e4a5" +uuid = "605ecd9f-84a6-4c9e-81e2-4798472b76a3" +version = "0.1.0" -[[Sockets]] +[[deps.SimpleTraits]] +deps = ["InteractiveUtils", "MacroTools"] +git-tree-sha1 = "5d7e3f4e11935503d3ecaf7186eac40602e7d231" +uuid = "699a6c99-e7fa-54fc-8d76-47d257e15c1d" +version = "0.9.4" + +[[deps.Sockets]] uuid = "6462fe0b-24de-5631-8697-dd941f90decc" -[[SortingAlgorithms]] -deps = ["DataStructures", "Random", "Test"] -git-tree-sha1 = "03f5898c9959f8115e30bc7226ada7d0df554ddd" +[[deps.SortingAlgorithms]] +deps = ["DataStructures"] +git-tree-sha1 = "a4ada03f999bd01b3a25dcaa30b2d929fe537e00" uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" -version = "0.3.1" +version = "1.1.0" -[[SparseArrays]] +[[deps.SparseArrays]] deps = ["LinearAlgebra", "Random"] uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" -[[SpecialFunctions]] -deps = ["ChainRulesCore", "OpenSpecFun_jll"] -git-tree-sha1 = "75394dbe2bd346beeed750fb02baa6445487b862" +[[deps.SpecialFunctions]] +deps = ["ChainRulesCore", "IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] +git-tree-sha1 = "d75bda01f8c31ebb72df80a46c88b25d1c79c56d" uuid = "276daf66-3868-5448-9aa4-cd146d93841b" -version = "1.2.1" +version = "2.1.7" -[[StaticArrays]] -deps = ["LinearAlgebra", "Random", "Statistics"] -git-tree-sha1 = "9da72ed50e94dbff92036da395275ed114e04d49" +[[deps.SplittablesBase]] +deps = ["Setfield", "Test"] +git-tree-sha1 = "e08a62abc517eb79667d0a29dc08a3b589516bb5" +uuid = "171d559e-b47b-412a-8079-5efa626c420e" +version = "0.1.15" + +[[deps.StaticArrays]] +deps = ["LinearAlgebra", "Random", "StaticArraysCore", "Statistics"] +git-tree-sha1 = "6954a456979f23d05085727adb17c4551c19ecd1" uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.0.1" +version = "1.5.12" + +[[deps.StaticArraysCore]] +git-tree-sha1 = "6b7ba252635a5eff6a0b0664a41ee140a1c9e72a" +uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" +version = "1.4.0" -[[Statistics]] +[[deps.Statistics]] deps = ["LinearAlgebra", "SparseArrays"] uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" -[[StatsBase]] -deps = ["DataAPI", "DataStructures", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics"] -git-tree-sha1 = "7bab7d4eb46b225b35179632852b595a3162cb61" +[[deps.StatsAPI]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "f9af7f195fb13589dd2e2d57fdb401717d2eb1f6" +uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0" +version = "1.5.0" + +[[deps.StatsBase]] +deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"] +git-tree-sha1 = "d1bf48bfcc554a3761a133fe3a9bb01488e06916" uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" -version = "0.33.2" +version = "0.33.21" -[[Test]] -deps = ["Distributed", "InteractiveUtils", "Logging", "Random"] +[[deps.StructArrays]] +deps = ["Adapt", "DataAPI", "GPUArraysCore", "StaticArraysCore", "Tables"] +git-tree-sha1 = "b03a3b745aa49b566f128977a7dd1be8711c5e71" +uuid = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" +version = "0.6.14" + +[[deps.TOML]] +deps = ["Dates"] +uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" +version = "1.0.0" + +[[deps.TableTraits]] +deps = ["IteratorInterfaceExtensions"] +git-tree-sha1 = "c06b2f539df1c6efa794486abfb6ed2022561a39" +uuid = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c" +version = "1.0.1" + +[[deps.Tables]] +deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "LinearAlgebra", "OrderedCollections", "TableTraits", "Test"] +git-tree-sha1 = "c79322d36826aa2f4fd8ecfa96ddb47b174ac78d" +uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" +version = "1.10.0" + +[[deps.Tar]] +deps = ["ArgTools", "SHA"] +uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" +version = "1.10.1" + +[[deps.Test]] +deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" -[[TimerOutputs]] -deps = ["Printf"] -git-tree-sha1 = "3318281dd4121ecf9713ce1383b9ace7d7476fdd" +[[deps.TimerOutputs]] +deps = ["ExprTools", "Printf"] +git-tree-sha1 = "f2fd3f288dfc6f507b0c3a2eb3bac009251e548b" uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" -version = "0.5.7" +version = "0.5.22" -[[TranscodingStreams]] -deps = ["Random", "Test"] -git-tree-sha1 = "7c53c35547de1c5b9d46a4797cf6d8253807108c" -uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" -version = "0.9.5" +[[deps.Transducers]] +deps = ["Adapt", "ArgCheck", "BangBang", "Baselet", "CompositionsBase", "DefineSingletons", "Distributed", "InitialValues", "Logging", "Markdown", "MicroCollections", "Requires", "Setfield", "SplittablesBase", "Tables"] +git-tree-sha1 = "c42fa452a60f022e9e087823b47e5a5f8adc53d5" +uuid = "28d57a85-8fef-5791-bfe6-a80928e7c999" +version = "0.4.75" -[[UUIDs]] +[[deps.UUIDs]] deps = ["Random", "SHA"] uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" -[[UnPack]] -git-tree-sha1 = "387c1f73762231e86e0c9c5443ce3b4a0a9a0c2b" -uuid = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" -version = "1.0.2" - -[[Unicode]] +[[deps.Unicode]] uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" -[[ZipFile]] -deps = ["Libdl", "Printf", "Zlib_jll"] -git-tree-sha1 = "c3a5637e27e914a7a445b8d0ad063d701931e9f7" -uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea" -version = "0.9.3" - -[[Zlib_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "320228915c8debb12cb434c59057290f0834dbf6" +[[deps.Zlib_jll]] +deps = ["Libdl"] uuid = "83775a58-1f1d-513f-b197-d71354ab007a" -version = "1.2.11+18" +version = "1.2.12+3" -[[Zygote]] -deps = ["AbstractFFTs", "ChainRules", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics", "ZygoteRules"] -git-tree-sha1 = "52032f3eb3bf383df34f5455c031457632e8c6d4" +[[deps.Zygote]] +deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "GPUArrays", "GPUArraysCore", "IRTools", "InteractiveUtils", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "Random", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"] +git-tree-sha1 = "60c4588669caf2a23363e0b191d3a459773e91b4" uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" -version = "0.6.1" +version = "0.6.53" -[[ZygoteRules]] +[[deps.ZygoteRules]] deps = ["MacroTools"] -git-tree-sha1 = "9e7a1e8ca60b742e508a315c17eef5211e7fbfd7" +git-tree-sha1 = "8c1a8e4dfacb1fd631745552c8db35d0deb09ea0" uuid = "700de1a5-db45-46bc-99cf-38207098b444" -version = "0.2.1" +version = "0.2.2" + +[[deps.libblastrampoline_jll]] +deps = ["Artifacts", "Libdl", "OpenBLAS_jll"] +uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" +version = "5.1.1+0" + +[[deps.nghttp2_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" +version = "1.48.0+0" + +[[deps.p7zip_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" +version = "17.4.0+0" diff --git a/text/char-rnn/Project.toml b/text/char-rnn/Project.toml index 3eaad834..50eded9f 100644 --- a/text/char-rnn/Project.toml +++ b/text/char-rnn/Project.toml @@ -1,6 +1,6 @@ [deps] Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" -Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a" +OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" [compat] diff --git a/text/char-rnn/char-rnn.jl b/text/char-rnn/char-rnn.jl index 354ee483..024d9a62 100644 --- a/text/char-rnn/char-rnn.jl +++ b/text/char-rnn/char-rnn.jl @@ -26,15 +26,15 @@ # To run this example, we need the following packages: using Flux -using Flux: onehot, chunk, batchseq, throttle, logitcrossentropy +using Flux: chunk, batchseq, logitcrossentropy +using OneHotArrays using StatsBase: wsample using Base.Iterators: partition -using Parameters: @with_kw using Random: shuffle # We set default values for the hyperparameters: -@with_kw mutable struct Args +Base.@kwdef mutable struct Args lr::Float64 = 1e-2 # Learning rate seqlen::Int = 50 # Length of batch sequences batchsz::Int = 50 # Number of sequences in each batch @@ -49,7 +49,7 @@ end # for training the model: -function getdata(args) +function getdata(args::Args) ## Download the data if not downloaded as 'input.txt' isfile("input.txt") || download( "https://cs.stanford.edu/people/karpathy/char-rnn/shakespeare_input.txt", @@ -87,11 +87,11 @@ end # We create the RNN with two Flux’s LSTM layers and an output layer of the size of the alphabet: -function build_model(N) +function build_model(N::Int) return Chain( - LSTM(N, 128), - LSTM(128, 128), - Dense(128, N)) + LSTM(N => 128), + LSTM(128 => 128), + Dense(128 => N)) end # The size of the input and output layers is the same as the size of the alphabet. @@ -124,30 +124,29 @@ function train(; kws...) trainX, trainY, testX, testY = device.((trainX, trainY, testX, testY)) ## Constructing Model - m = build_model(N) |> device + model = build_model(N) |> device - function loss(xs, ys) + function loss(m, xs, ys) Flux.reset!(m) return sum(logitcrossentropy.([m(x) for x in xs], ys)) end ## Training - opt = ADAM(args.lr) + opt_state = Flux.setup(Adam(args.lr), model) - @info "Start Training, total $(args.epochs) epochs" for epoch = 1:args.epochs - @info "Epoch $(epoch) / $(args.epochs)" + @info "Training, epoch $(epoch) / $(args.epochs)" Flux.train!( loss, - Flux.params(m), + model, zip(trainX, trainY), - opt + opt_state ) ## Show loss-per-character over the test set - @show sum(loss.(testX, testY)) / (args.batchsz * args.seqlen * length(testX)) + @show sum(loss.(Ref(model), testX, testY)) / (args.batchsz * args.seqlen * length(testX)) end - return m, alphabet + return model, alphabet end # The function `train` performs the following tasks: @@ -160,8 +159,6 @@ end # * Sets the [ADAM optimiser](https://fluxml.ai/Flux.jl/stable/training/optimisers/#Flux.Optimise.RADAM) with the learning rate *lr* we defined above. # * Creates a [callback](https://fluxml.ai/Flux.jl/stable/training/training/#Callbacks) *evalcb* so that you can observe the training process (print the loss value). # * Runs the training loop using [Flux’s train!](https://fluxml.ai/Flux.jl/stable/training/training/#Flux.Optimise.train!). -# It uses the function [throttle](https://fluxml.ai/Flux.jl/stable/utilities/#Flux.throttle) so that the callback *evalcb* -# can only be triggered at most once during timeout seconds (as defined above). # ## Test the model