From 273b9faa5e4e6cbda2d0b877f56de984fe1c99d4 Mon Sep 17 00:00:00 2001 From: Tristan Murphy <72839119+inflectrix@users.noreply.github.com> Date: Tue, 16 Apr 2024 15:23:07 +0000 Subject: [PATCH 01/27] rename CONTRIBUTING --- CONTRIBUTING => CONTRIBUTING.md | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename CONTRIBUTING => CONTRIBUTING.md (100%) diff --git a/CONTRIBUTING b/CONTRIBUTING.md similarity index 100% rename from CONTRIBUTING rename to CONTRIBUTING.md From 3f50d080e9d248506da5feac73949d9d656210ce Mon Sep 17 00:00:00 2001 From: Tristan Murphy <72839119+inflectrix@users.noreply.github.com> Date: Tue, 16 Apr 2024 15:23:57 +0000 Subject: [PATCH 02/27] add flush_state to readme example --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index ad775e2..1a97864 100644 --- a/README.md +++ b/README.md @@ -58,6 +58,8 @@ fn fitness(dna: &MyAgentDNA) -> f32 { let above = n > 0.5; let res = agent.network.predict([n]); + agent.network.flush_state(); + let resi = res.iter().max_index(); if resi == 0 ^ above { From a94198a9bc2776a451bca27211d0c72291df5842 Mon Sep 17 00:00:00 2001 From: Tristan Murphy <72839119+inflectrix@users.noreply.github.com> Date: Tue, 16 Apr 2024 16:14:42 +0000 Subject: [PATCH 03/27] create basic log test --- src/lib.rs | 64 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 64 insertions(+) diff --git a/src/lib.rs b/src/lib.rs index ee9f769..ac569b0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -23,3 +23,67 @@ pub use topology::*; #[cfg(feature = "serde")] pub use nnt_serde::*; + +#[cfg(test)] +mod tests { + use super::*; + use rand::prelude::*; + + #[derive(RandomlyMutable, DivisionReproduction, Clone)] + struct AgentDNA { + network: NeuralNetworkTopology<2, 1>, + } + + impl Prunable for AgentDNA {} + + impl GenerateRandom for AgentDNA { + fn gen_random(rng: &mut impl Rng) -> Self { + Self { + network: NeuralNetworkTopology::new(0.01, 3, rng), + } + } + } + + #[test] + fn basic_test() { + let fitness = |g: &AgentDNA| { + let network = NeuralNetwork::from(&g.network); + let mut fitness = 0.; + let mut rng = rand::thread_rng(); + + for _ in 0..100 { + let n = rng.gen::() * 10000.; + let base = rng.gen::() * 10.; + let expected = n.log(base); + + let [answer] = network.predict([n, base]); + network.flush_state(); + + fitness += 5. / (answer - expected).abs(); + } + + fitness + }; + + let mut rng = rand::thread_rng(); + + let mut sim = GeneticSim::new( + Vec::gen_random(&mut rng, 100), + fitness, + division_pruning_nextgen, + ); + + for _ in 0..100 { + sim.next_generation(); + } + + let mut fits: Vec<_> = sim.genomes + .iter() + .map(fitness) + .collect(); + + fits.sort_by(|a, b| a.partial_cmp(&b).unwrap()); + + dbg!(fits); + } +} \ No newline at end of file From 339b90b3c7970a9b88b162db34551d190d947324 Mon Sep 17 00:00:00 2001 From: Tristan Murphy <72839119+inflectrix@users.noreply.github.com> Date: Tue, 16 Apr 2024 18:05:13 +0000 Subject: [PATCH 04/27] create plotters example --- Cargo.lock | 744 ++++++++++++++++++++++++++++++++++++++++++++++- Cargo.toml | 1 + examples/plot.rs | 135 +++++++++ 3 files changed, 879 insertions(+), 1 deletion(-) create mode 100644 examples/plot.rs diff --git a/Cargo.lock b/Cargo.lock index be4d7b8..5c98cd1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,33 @@ # It is not intended for manual editing. version = 3 +[[package]] +name = "adler" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" + +[[package]] +name = "android-tzdata" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0" + +[[package]] +name = "android_system_properties" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" +dependencies = [ + "libc", +] + +[[package]] +name = "autocfg" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1fdabc7756949593fe60f30ec81974b613357de856987752631dea1e3394c80" + [[package]] name = "bincode" version = "1.3.3" @@ -11,18 +38,135 @@ dependencies = [ "serde", ] +[[package]] +name = "bitflags" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" + [[package]] name = "bitflags" version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cf4b9d6a944f767f8e5e0db018570623c85f3d925ac718db4e06d0187adb21c1" +[[package]] +name = "bumpalo" +version = "3.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" + +[[package]] +name = "bytemuck" +version = "1.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d6d68c57235a3a081186990eca2867354726650f42f7516ca50c28d6281fd15" + +[[package]] +name = "byteorder" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" + +[[package]] +name = "cc" +version = "1.0.94" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "17f6e324229dc011159fcc089755d1e2e216a90d43a7dea6853ca740b84f35e7" + [[package]] name = "cfg-if" version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "chrono" +version = "0.4.38" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a21f936df1771bf62b77f047b726c4625ff2e8aa607c01ec06e5a05bd8463401" +dependencies = [ + "android-tzdata", + "iana-time-zone", + "js-sys", + "num-traits", + "wasm-bindgen", + "windows-targets", +] + +[[package]] +name = "color_quant" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d7b894f5411737b7867f4827955924d7c254fc9f4d91a6aad6b097804b1018b" + +[[package]] +name = "const-cstr" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed3d0b5ff30645a68f35ece8cea4556ca14ef8a1651455f789a099a0513532a6" + +[[package]] +name = "core-foundation" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "core-foundation-sys" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f" + +[[package]] +name = "core-graphics" +version = "0.22.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2581bbab3b8ffc6fcbd550bf46c355135d16e9ff2a6ea032ad6b9bf1d7efe4fb" +dependencies = [ + "bitflags 1.3.2", + "core-foundation", + "core-graphics-types", + "foreign-types", + "libc", +] + +[[package]] +name = "core-graphics-types" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45390e6114f68f718cc7a830514a96f903cccd70d02a8f6d9f643ac4ba45afaf" +dependencies = [ + "bitflags 1.3.2", + "core-foundation", + "libc", +] + +[[package]] +name = "core-text" +version = "19.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "99d74ada66e07c1cefa18f8abfba765b486f250de2e4a999e5727fc0dd4b4a25" +dependencies = [ + "core-foundation", + "core-graphics", + "foreign-types", + "libc", +] + +[[package]] +name = "crc32fast" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b3855a8a784b474f333699ef2bbca9db2c4a1f6d9088a90a2d25b1eb53111eaa" +dependencies = [ + "cfg-if", +] + [[package]] name = "crossbeam-deque" version = "0.8.5" @@ -48,12 +192,140 @@ version = "0.8.19" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "248e3bacc7dc6baa3b21e405ee045c3047101a49145e7e9eca583ab4c2ca5345" +[[package]] +name = "dirs-next" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b98cf8ebf19c3d1b223e151f99a4f9f0690dca41414773390fc824184ac833e1" +dependencies = [ + "cfg-if", + "dirs-sys-next", +] + +[[package]] +name = "dirs-sys-next" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ebda144c4fe02d1f7ea1a7d9641b6fc6b580adcfa024ae48797ecdeb6825b4d" +dependencies = [ + "libc", + "redox_users", + "winapi", +] + +[[package]] +name = "dlib" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "330c60081dcc4c72131f8eb70510f1ac07223e5d4163db481a04a0befcffa412" +dependencies = [ + "libloading", +] + +[[package]] +name = "dwrote" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "439a1c2ba5611ad3ed731280541d36d2e9c4ac5e7fb818a27b604bdc5a6aa65b" +dependencies = [ + "lazy_static", + "libc", + "winapi", + "wio", +] + [[package]] name = "either" version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" +[[package]] +name = "fdeflate" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4f9bfee30e4dedf0ab8b422f03af778d9612b63f502710fc500a334ebe2de645" +dependencies = [ + "simd-adler32", +] + +[[package]] +name = "flate2" +version = "1.0.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "46303f565772937ffe1d394a4fac6f411c6013172fadde9dcdb1e147a086940e" +dependencies = [ + "crc32fast", + "miniz_oxide", +] + +[[package]] +name = "float-ord" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7bad48618fdb549078c333a7a8528acb57af271d0433bdecd523eb620628364e" + +[[package]] +name = "font-kit" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21fe28504d371085fae9ac7a3450f0b289ab71e07c8e57baa3fb68b9e57d6ce5" +dependencies = [ + "bitflags 1.3.2", + "byteorder", + "core-foundation", + "core-graphics", + "core-text", + "dirs-next", + "dwrote", + "float-ord", + "freetype", + "lazy_static", + "libc", + "log", + "pathfinder_geometry", + "pathfinder_simd", + "walkdir", + "winapi", + "yeslogic-fontconfig-sys", +] + +[[package]] +name = "foreign-types" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" +dependencies = [ + "foreign-types-shared", +] + +[[package]] +name = "foreign-types-shared" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" + +[[package]] +name = "freetype" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "efc8599a3078adf8edeb86c71e9f8fa7d88af5ca31e806a867756081f90f5d83" +dependencies = [ + "freetype-sys", + "libc", +] + +[[package]] +name = "freetype-sys" +version = "0.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "66ee28c39a43d89fbed8b4798fb4ba56722cfd2b5af81f9326c27614ba88ecd5" +dependencies = [ + "cc", + "libc", + "pkg-config", +] + [[package]] name = "genetic-rs" version = "0.5.1" @@ -98,12 +370,74 @@ dependencies = [ "wasi", ] +[[package]] +name = "gif" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "80792593675e051cf94a4b111980da2ba60d4a83e43e0048c5693baab3977045" +dependencies = [ + "color_quant", + "weezl", +] + +[[package]] +name = "iana-time-zone" +version = "0.1.60" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7ffbb5a1b541ea2561f8c41c087286cc091e21e556a4f09a8f6cbf17b69b141" +dependencies = [ + "android_system_properties", + "core-foundation-sys", + "iana-time-zone-haiku", + "js-sys", + "wasm-bindgen", + "windows-core", +] + +[[package]] +name = "iana-time-zone-haiku" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" +dependencies = [ + "cc", +] + +[[package]] +name = "image" +version = "0.24.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5690139d2f55868e080017335e4b94cb7414274c74f1669c84fb5feba2c9f69d" +dependencies = [ + "bytemuck", + "byteorder", + "color_quant", + "jpeg-decoder", + "num-traits", + "png", +] + [[package]] name = "itoa" version = "1.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c" +[[package]] +name = "jpeg-decoder" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f5d4a7da358eff58addd2877a45865158f0d78c911d43a5784ceb7bbf52833b0" + +[[package]] +name = "js-sys" +version = "0.3.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29c15563dc2726973df627357ce0c9ddddbea194836909d655df6a75d2cf296d" +dependencies = [ + "wasm-bindgen", +] + [[package]] name = "lazy_static" version = "1.4.0" @@ -116,14 +450,51 @@ version = "0.2.153" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd" +[[package]] +name = "libloading" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c2a198fb6b0eada2a8df47933734e6d35d350665a33a3593d7164fa52c75c19" +dependencies = [ + "cfg-if", + "windows-targets", +] + +[[package]] +name = "libredox" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c0ff37bd590ca25063e35af745c343cb7a0271906fb7b37e4813e8f79f00268d" +dependencies = [ + "bitflags 2.5.0", + "libc", +] + +[[package]] +name = "log" +version = "0.4.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c" + +[[package]] +name = "miniz_oxide" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d811f3e15f28568be3407c8e7fdb6514c1cda3cb30683f15b6a1a1dc4ea14a7" +dependencies = [ + "adler", + "simd-adler32", +] + [[package]] name = "neat" version = "0.5.1" dependencies = [ "bincode", - "bitflags", + "bitflags 2.5.0", "genetic-rs", "lazy_static", + "plotters", "rand", "rayon", "serde", @@ -131,6 +502,105 @@ dependencies = [ "serde_json", ] +[[package]] +name = "num-traits" +version = "0.2.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da0df0e5185db44f69b44f26786fe401b6c293d1907744beaa7fa62b2e5a517a" +dependencies = [ + "autocfg", +] + +[[package]] +name = "once_cell" +version = "1.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" + +[[package]] +name = "pathfinder_geometry" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b7b7e7b4ea703700ce73ebf128e1450eb69c3a8329199ffbfb9b2a0418e5ad3" +dependencies = [ + "log", + "pathfinder_simd", +] + +[[package]] +name = "pathfinder_simd" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebf45976c56919841273f2a0fc684c28437e2f304e264557d9c72be5d5a718be" +dependencies = [ + "rustc_version", +] + +[[package]] +name = "pkg-config" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d231b230927b5e4ad203db57bbcbee2802f6bce620b1e4a9024a07d94e2907ec" + +[[package]] +name = "plotters" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2c224ba00d7cadd4d5c660deaf2098e5e80e07846537c51f9cfa4be50c1fd45" +dependencies = [ + "chrono", + "font-kit", + "image", + "lazy_static", + "num-traits", + "pathfinder_geometry", + "plotters-backend", + "plotters-bitmap", + "plotters-svg", + "ttf-parser", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "plotters-backend" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e76628b4d3a7581389a35d5b6e2139607ad7c75b17aed325f210aa91f4a9609" + +[[package]] +name = "plotters-bitmap" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0cebbe1f70205299abc69e8b295035bb52a6a70ee35474ad10011f0a4efb8543" +dependencies = [ + "gif", + "image", + "plotters-backend", +] + +[[package]] +name = "plotters-svg" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38f6d39893cca0701371e3c27294f09797214b86f1fb951b89ade8ec04e2abab" +dependencies = [ + "plotters-backend", +] + +[[package]] +name = "png" +version = "0.17.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06e4b0d3d1312775e782c86c91a111aa1f910cbb65e1337f9975b5f9a554b5e1" +dependencies = [ + "bitflags 1.3.2", + "crc32fast", + "fdeflate", + "flate2", + "miniz_oxide", +] + [[package]] name = "ppv-lite86" version = "0.2.17" @@ -205,18 +675,53 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "redox_users" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd283d9651eeda4b2a83a43c1c91b266c40fd76ecd39a50a8c630ae69dc72891" +dependencies = [ + "getrandom", + "libredox", + "thiserror", +] + [[package]] name = "replace_with" version = "0.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3a8614ee435691de62bcffcf4a66d91b3594bf1428a5722e79103249a095690" +[[package]] +name = "rustc_version" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366" +dependencies = [ + "semver", +] + [[package]] name = "ryu" version = "1.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e86697c916019a8588c99b5fac3cead74ec0b4b819707a682fd4d23fa0ce1ba1" +[[package]] +name = "same-file" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" +dependencies = [ + "winapi-util", +] + +[[package]] +name = "semver" +version = "1.0.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92d43fe69e652f3df9bdc2b85b2854a0825b86e4fb76bc44d945137d053639ca" + [[package]] name = "serde" version = "1.0.197" @@ -257,6 +762,12 @@ dependencies = [ "serde", ] +[[package]] +name = "simd-adler32" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d66dc143e6b11c1eddc06d5c423cfc97062865baf299914ab64caa38182078fe" + [[package]] name = "syn" version = "2.0.51" @@ -268,14 +779,245 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "thiserror" +version = "1.0.58" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "03468839009160513471e86a034bb2c5c0e4baae3b43f79ffc55c4a5427b3297" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.58" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c61f3ba182994efc43764a46c018c347bc492c79f024e705f46567b418f6d4f7" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "ttf-parser" +version = "0.17.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "375812fa44dab6df41c195cd2f7fecb488f6c09fbaafb62807488cefab642bff" + [[package]] name = "unicode-ident" version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" +[[package]] +name = "walkdir" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b" +dependencies = [ + "same-file", + "winapi-util", +] + [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" + +[[package]] +name = "wasm-bindgen" +version = "0.2.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4be2531df63900aeb2bca0daaaddec08491ee64ceecbee5076636a3b026795a8" +dependencies = [ + "cfg-if", + "wasm-bindgen-macro", +] + +[[package]] +name = "wasm-bindgen-backend" +version = "0.2.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "614d787b966d3989fa7bb98a654e369c762374fd3213d212cfc0251257e747da" +dependencies = [ + "bumpalo", + "log", + "once_cell", + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-macro" +version = "0.2.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1f8823de937b71b9460c0c34e25f3da88250760bec0ebac694b49997550d726" +dependencies = [ + "quote", + "wasm-bindgen-macro-support", +] + +[[package]] +name = "wasm-bindgen-macro-support" +version = "0.2.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-backend", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-shared" +version = "0.2.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af190c94f2773fdb3729c55b007a722abb5384da03bc0986df4c289bf5567e96" + +[[package]] +name = "web-sys" +version = "0.3.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77afa9a11836342370f4817622a2f0f418b134426d91a82dfb48f532d2ec13ef" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "weezl" +version = "0.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53a85b86a771b1c87058196170769dd264f66c0782acf1ae6cc51bfd64b39082" + +[[package]] +name = "winapi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" +dependencies = [ + "winapi-i686-pc-windows-gnu", + "winapi-x86_64-pc-windows-gnu", +] + +[[package]] +name = "winapi-i686-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" + +[[package]] +name = "winapi-util" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f29e6f9198ba0d26b4c9f07dbe6f9ed633e1f3d5b8b414090084349e46a52596" +dependencies = [ + "winapi", +] + +[[package]] +name = "winapi-x86_64-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" + +[[package]] +name = "windows-core" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" +dependencies = [ + "windows-targets", +] + +[[package]] +name = "windows-targets" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f0713a46559409d202e70e28227288446bf7841d3211583a4b53e3f6d96e7eb" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_gnullvm", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7088eed71e8b8dda258ecc8bac5fb1153c5cffaf2578fc8ff5d61e23578d3263" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9985fd1504e250c615ca5f281c3f7a6da76213ebd5ccc9561496568a2752afb6" + +[[package]] +name = "windows_i686_gnu" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88ba073cf16d5372720ec942a8ccbf61626074c6d4dd2e745299726ce8b89670" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87f4261229030a858f36b459e748ae97545d6f1ec60e5e0d6a3d32e0dc232ee9" + +[[package]] +name = "windows_i686_msvc" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db3c2bf3d13d5b658be73463284eaf12830ac9a26a90c717b7f771dfe97487bf" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e4246f76bdeff09eb48875a0fd3e2af6aada79d409d33011886d3e1581517d9" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "852298e482cd67c356ddd9570386e2862b5673c85bd5f88df9ab6802b334c596" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0" + +[[package]] +name = "wio" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d129932f4644ac2396cb456385cbf9e63b5b30c6e8dc4820bdca4eb082037a5" +dependencies = [ + "winapi", +] + +[[package]] +name = "yeslogic-fontconfig-sys" +version = "3.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2bbd69036d397ebbff671b1b8e4d918610c181c5a16073b96f984a38d08c386" +dependencies = [ + "const-cstr", + "dlib", + "once_cell", + "pkg-config", +] diff --git a/Cargo.toml b/Cargo.toml index 91247ad..f2f976e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -37,3 +37,4 @@ serde-big-array = { version = "0.5.1", optional = true } [dev-dependencies] bincode = "1.3.3" serde_json = "1.0.114" +plotters = "0.3.5" \ No newline at end of file diff --git a/examples/plot.rs b/examples/plot.rs new file mode 100644 index 0000000..48db937 --- /dev/null +++ b/examples/plot.rs @@ -0,0 +1,135 @@ +use std::{error::Error, sync::{Arc, Mutex}}; + +use neat::*; +use rand::prelude::*; +use plotters::prelude::*; + +#[derive(RandomlyMutable, DivisionReproduction, Clone)] +struct AgentDNA { + network: NeuralNetworkTopology<2, 1>, +} + +impl Prunable for AgentDNA {} + +impl GenerateRandom for AgentDNA { + fn gen_random(rng: &mut impl Rng) -> Self { + Self { + network: NeuralNetworkTopology::new(0.01, 3, rng), + } + } +} + +fn fitness(g: &AgentDNA) -> f32 { + let network = NeuralNetwork::from(&g.network); + let mut fitness = 0.; + let mut rng = rand::thread_rng(); + + for _ in 0..100 { + let n = rng.gen::() * 10000.; + let base = rng.gen::() * 10.; + let expected = n.log(base); + + let [answer] = network.predict([n, base]); + network.flush_state(); + + fitness += 5. / (answer - expected).abs(); + } + + fitness +} + +struct PlottingNG { + performance_stats: Arc>>, +} + +impl NextgenFn for PlottingNG { + fn next_gen(&self, fitness: Vec<(AgentDNA, f32)>) -> Vec { + let l = fitness.len(); + + let high = fitness[0].1; + + let median = fitness[l / 2].1; + + let low = fitness[l-1].1; + + let mut ps = self.performance_stats.lock().unwrap(); + ps.push(PerformanceStats { high, median, low }); + + division_pruning_nextgen(fitness) + } +} + +struct PerformanceStats { + high: f32, + median: f32, + low: f32, +} + +const OUTPUT_FILE_NAME: &'static str = "fitness-plot.png"; +const GENS: usize = 100; +fn main() -> Result<(), Box> { + let mut rng = rand::thread_rng(); + + let performance_stats = Arc::new(Mutex::new(Vec::with_capacity(GENS))); + let ng = PlottingNG { performance_stats: performance_stats.clone() }; + + let mut sim = GeneticSim::new( + Vec::gen_random(&mut rng, 100), + fitness, + ng, + ); + + println!("Training..."); + + for _ in 0..GENS { + sim.next_generation(); + } + + println!("Training complete, collecting data and building chart..."); + + let root = BitMapBackend::new(OUTPUT_FILE_NAME, (640, 480)).into_drawing_area(); + root.fill(&WHITE)?; + + let mut chart = ChartBuilder::on(&root) + .caption("agent fitness over gens", ("sans-serif", 50).into_font()) + .margin(5) + .x_label_area_size(30) + .y_label_area_size(30) + .build_cartesian_2d(0usize..100, 0f32..200.0)?; + + chart.configure_mesh().draw()?; + + let data: Vec<_> = Arc::into_inner(performance_stats).unwrap().into_inner().unwrap() + .into_iter() + .enumerate() + .collect(); + let highs = data + .iter() + .map(|(i, PerformanceStats { high, .. })| (*i, *high)); + + let medians = data + .iter() + .map(|(i, PerformanceStats { median, .. })| (*i, *median)); + + let lows = data + .iter() + .map(|(i, PerformanceStats { low, .. })| (*i, *low)); + + chart + .draw_series(LineSeries::new(highs, &GREEN))? + .label("high"); + + chart + .draw_series(LineSeries::new(medians, &YELLOW))? + .label("median"); + + chart + .draw_series(LineSeries::new(lows, &RED))? + .label("low"); + + root.present()?; + + println!("Complete"); + + Ok(()) +} \ No newline at end of file From 5cddae7b31b34c8a0f5d37398558f833f3e22560 Mon Sep 17 00:00:00 2001 From: Tristan Murphy <72839119+inflectrix@users.noreply.github.com> Date: Tue, 16 Apr 2024 18:25:49 +0000 Subject: [PATCH 05/27] small changes --- examples/plot.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/plot.rs b/examples/plot.rs index 48db937..59e3a24 100644 --- a/examples/plot.rs +++ b/examples/plot.rs @@ -91,7 +91,7 @@ fn main() -> Result<(), Box> { root.fill(&WHITE)?; let mut chart = ChartBuilder::on(&root) - .caption("agent fitness over gens", ("sans-serif", 50).into_font()) + .caption("agent fitness values per generation", ("sans-serif", 50).into_font()) .margin(5) .x_label_area_size(30) .y_label_area_size(30) @@ -103,6 +103,7 @@ fn main() -> Result<(), Box> { .into_iter() .enumerate() .collect(); + let highs = data .iter() .map(|(i, PerformanceStats { high, .. })| (*i, *high)); From 728cbdeca4e6009399b1167bdd25a535da491d06 Mon Sep 17 00:00:00 2001 From: Tristan Murphy <72839119+inflectrix@users.noreply.github.com> Date: Tue, 16 Apr 2024 18:29:56 +0000 Subject: [PATCH 06/27] more configuration --- examples/plot.rs | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/examples/plot.rs b/examples/plot.rs index 59e3a24..605ed69 100644 --- a/examples/plot.rs +++ b/examples/plot.rs @@ -103,7 +103,7 @@ fn main() -> Result<(), Box> { .into_iter() .enumerate() .collect(); - + let highs = data .iter() .map(|(i, PerformanceStats { high, .. })| (*i, *high)); @@ -128,6 +128,12 @@ fn main() -> Result<(), Box> { .draw_series(LineSeries::new(lows, &RED))? .label("low"); + chart + .configure_series_labels() + .background_style(&WHITE.mix(0.8)) + .border_style(&BLACK) + .draw()?; + root.present()?; println!("Complete"); From 91c3f9f1463096178cf011acc8e8aadd730f7f46 Mon Sep 17 00:00:00 2001 From: Tristan Murphy <72839119+inflectrix@users.noreply.github.com> Date: Wed, 17 Apr 2024 11:31:35 +0000 Subject: [PATCH 07/27] make plotting ng more generic --- examples/plot.rs | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/examples/plot.rs b/examples/plot.rs index 605ed69..6cd555e 100644 --- a/examples/plot.rs +++ b/examples/plot.rs @@ -38,11 +38,12 @@ fn fitness(g: &AgentDNA) -> f32 { fitness } -struct PlottingNG { +struct PlottingNG> { performance_stats: Arc>>, + actual_ng: F, } -impl NextgenFn for PlottingNG { +impl> NextgenFn for PlottingNG { fn next_gen(&self, fitness: Vec<(AgentDNA, f32)>) -> Vec { let l = fitness.len(); @@ -55,7 +56,7 @@ impl NextgenFn for PlottingNG { let mut ps = self.performance_stats.lock().unwrap(); ps.push(PerformanceStats { high, median, low }); - division_pruning_nextgen(fitness) + self.actual_ng.next_gen(fitness) } } @@ -71,7 +72,7 @@ fn main() -> Result<(), Box> { let mut rng = rand::thread_rng(); let performance_stats = Arc::new(Mutex::new(Vec::with_capacity(GENS))); - let ng = PlottingNG { performance_stats: performance_stats.clone() }; + let ng = PlottingNG { performance_stats: performance_stats.clone(), actual_ng: division_pruning_nextgen }; let mut sim = GeneticSim::new( Vec::gen_random(&mut rng, 100), From f6d0df0493d2ec8b8cdc7ce8c978154470c449f5 Mon Sep 17 00:00:00 2001 From: Tristan Murphy <72839119+inflectrix@users.noreply.github.com> Date: Thu, 18 Apr 2024 11:54:27 +0000 Subject: [PATCH 08/27] fix test rayon feature --- src/lib.rs | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/lib.rs b/src/lib.rs index ac569b0..0de19a1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -65,10 +65,16 @@ mod tests { fitness }; + #[cfg(not(feature = "rayon"))] let mut rng = rand::thread_rng(); let mut sim = GeneticSim::new( + #[cfg(not(feature = "rayon"))] Vec::gen_random(&mut rng, 100), + + #[cfg(feature = "rayon")] + Vec::gen_random(100), + fitness, division_pruning_nextgen, ); From cc88ebfc8497ec4583a4fbd43d5ad7e53ef8d9ba Mon Sep 17 00:00:00 2001 From: Tristan Murphy <72839119+inflectrix@users.noreply.github.com> Date: Thu, 18 Apr 2024 11:55:23 +0000 Subject: [PATCH 09/27] cargo fmt --- src/lib.rs | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 0de19a1..0dd0b8c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -30,7 +30,7 @@ mod tests { use rand::prelude::*; #[derive(RandomlyMutable, DivisionReproduction, Clone)] - struct AgentDNA { + struct AgentDNA { network: NeuralNetworkTopology<2, 1>, } @@ -71,10 +71,8 @@ mod tests { let mut sim = GeneticSim::new( #[cfg(not(feature = "rayon"))] Vec::gen_random(&mut rng, 100), - #[cfg(feature = "rayon")] Vec::gen_random(100), - fitness, division_pruning_nextgen, ); @@ -83,13 +81,10 @@ mod tests { sim.next_generation(); } - let mut fits: Vec<_> = sim.genomes - .iter() - .map(fitness) - .collect(); + let mut fits: Vec<_> = sim.genomes.iter().map(fitness).collect(); fits.sort_by(|a, b| a.partial_cmp(&b).unwrap()); dbg!(fits); } -} \ No newline at end of file +} From b95084dd4d615c3a685eebca83df4c16e7133910 Mon Sep 17 00:00:00 2001 From: Tristan Murphy <72839119+inflectrix@users.noreply.github.com> Date: Thu, 18 Apr 2024 14:53:59 +0000 Subject: [PATCH 10/27] create custom activations example --- examples/custom_activation.rs | 92 +++++++++++++++++++++++++++++++++++ src/topology/activation.rs | 4 +- 2 files changed, 94 insertions(+), 2 deletions(-) create mode 100644 examples/custom_activation.rs diff --git a/examples/custom_activation.rs b/examples/custom_activation.rs new file mode 100644 index 0000000..f52882b --- /dev/null +++ b/examples/custom_activation.rs @@ -0,0 +1,92 @@ +//! An example implementation of a custom activation function. + +use neat::*; +use rand::prelude::*; + +#[derive(DivisionReproduction, RandomlyMutable, Clone)] +struct AgentDNA { + network: NeuralNetworkTopology<2, 2>, +} + +impl Prunable for AgentDNA {} + +impl GenerateRandom for AgentDNA { + fn gen_random(rng: &mut impl Rng) -> Self { + Self { + network: NeuralNetworkTopology::new(0.01, 3, rng), + } + } +} + +fn fitness(g: &AgentDNA) -> f32 { + let network: NeuralNetwork<2, 2> = NeuralNetwork::from(&g.network); + let mut fitness = 0.; + let mut rng = rand::thread_rng(); + + for _ in 0..50 { + let n = rng.gen::(); + let n2 = rng.gen::(); + + let expected = if (n + n2) / 2. >= 0.5 { + 0 + } else { + 1 + }; + + let result = network.predict([n, n2]); + network.flush_state(); + + // partial_cmp chance of returning None in this smh + let result = result.iter().max_index(); + + if result == expected { + fitness += 1.; + } else { + fitness -= 1.; + } + } + + fitness +} + +#[cfg(feature = "serde")] +fn serde_nextgen(rewards: Vec<(AgentDNA, f32)>) -> Vec { + let max = rewards + .iter() + .max_by(|(_, ra), (_, rb)| ra.total_cmp(rb)) + .unwrap(); + + let ser = NNTSerde::from(&max.0.network); + let data = serde_json::to_string_pretty(&ser).unwrap(); + std::fs::write("best-agent.json", data).expect("Failed to write to file"); + + division_pruning_nextgen(rewards) +} + +fn main() { + let log_activation = activation_fn!(f32::log10); + register_activation(log_activation); + + #[cfg(not(feature = "rayon"))] + let mut rng = rand::thread_rng(); + + let mut sim = GeneticSim::new( + #[cfg(not(feature = "rayon"))] + Vec::gen_random(&mut rng, 100), + + #[cfg(feature = "rayon")] + Vec::gen_random(100), + + fitness, + + #[cfg(not(feature = "serde"))] + division_pruning_nextgen, + + #[cfg(feature = "serde")] + serde_nextgen, + ); + + for _ in 0..200 { + sim.next_generation(); + } +} \ No newline at end of file diff --git a/src/topology/activation.rs b/src/topology/activation.rs index a711851..5bf9540 100644 --- a/src/topology/activation.rs +++ b/src/topology/activation.rs @@ -15,11 +15,11 @@ use crate::NeuronLocation; #[macro_export] macro_rules! activation_fn { ($F: path) => { - ActivationFn::new(Arc::new($F), ActivationScope::default(), stringify!($F).into()) + ActivationFn::new(std::sync::Arc::new($F), ActivationScope::default(), stringify!($F).into()) }; ($F: path, $S: expr) => { - ActivationFn::new(Arc::new($F), $S, stringify!($F).into()) + ActivationFn::new(std::sync::Arc::new($F), $S, stringify!($F).into()) }; {$($F: path),*} => { From 35868795738bb443c31c09371dc4d76ee56fa6e5 Mon Sep 17 00:00:00 2001 From: Tristan Murphy <72839119+inflectrix@users.noreply.github.com> Date: Thu, 18 Apr 2024 15:01:37 +0000 Subject: [PATCH 11/27] fix opposite high and low --- examples/plot.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/plot.rs b/examples/plot.rs index 6cd555e..4fa4c51 100644 --- a/examples/plot.rs +++ b/examples/plot.rs @@ -47,11 +47,11 @@ impl> NextgenFn for PlottingNG { fn next_gen(&self, fitness: Vec<(AgentDNA, f32)>) -> Vec { let l = fitness.len(); - let high = fitness[0].1; + let high = fitness[l-1].1; let median = fitness[l / 2].1; - let low = fitness[l-1].1; + let low = fitness[0].1; let mut ps = self.performance_stats.lock().unwrap(); ps.push(PerformanceStats { high, median, low }); From 27e972af6f6fcaf885cbc96de7dc052405840897 Mon Sep 17 00:00:00 2001 From: Tristan Murphy <72839119+HyperCodec@users.noreply.github.com> Date: Mon, 6 May 2024 10:19:37 -0400 Subject: [PATCH 12/27] Update Cargo.toml --- Cargo.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 91247ad..96767ef 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,9 +3,9 @@ name = "neat" description = "Crate for working with NEAT in rust" version = "0.5.1" edition = "2021" -authors = ["Inflectrix"] -repository = "https://github.com/inflectrix/neat" -homepage = "https://github.com/inflectrix/neat" +authors = ["HyperCodec"] +repository = "https://github.com/HyperCodec/neat" +homepage = "https://github.com/HyperCodec/neat" readme = "README.md" keywords = ["genetic", "machine-learning", "ai", "algorithm", "evolution"] categories = ["algorithms", "science", "simulation"] From 4b8cef0f7a2d226a6bab62d9e8fde1996978a018 Mon Sep 17 00:00:00 2001 From: Tristan Murphy <72839119+inflectrix@users.noreply.github.com> Date: Wed, 15 May 2024 11:42:36 +0000 Subject: [PATCH 13/27] use svgbackend (now it hangs for some reason) --- .gitignore | 3 ++- examples/plot.rs | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index 1b71596..a6e0cb6 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ /target/ -/.vscode/ \ No newline at end of file +/.vscode/ +best-agent.json \ No newline at end of file diff --git a/examples/plot.rs b/examples/plot.rs index 4fa4c51..2be99c9 100644 --- a/examples/plot.rs +++ b/examples/plot.rs @@ -88,7 +88,7 @@ fn main() -> Result<(), Box> { println!("Training complete, collecting data and building chart..."); - let root = BitMapBackend::new(OUTPUT_FILE_NAME, (640, 480)).into_drawing_area(); + let root = SVGBackend::new(OUTPUT_FILE_NAME, (640, 480)).into_drawing_area(); root.fill(&WHITE)?; let mut chart = ChartBuilder::on(&root) From 6a7090ace3522817ed9979c21efa8fdb42d53112 Mon Sep 17 00:00:00 2001 From: Tristan Murphy <72839119+inflectrix@users.noreply.github.com> Date: Wed, 15 May 2024 13:36:47 +0000 Subject: [PATCH 14/27] fix arc::into_inner failure --- .gitignore | 3 ++- examples/plot.rs | 5 ++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index a6e0cb6..b2d8069 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ /target/ /.vscode/ -best-agent.json \ No newline at end of file +best-agent.json +fitness-plot.svg \ No newline at end of file diff --git a/examples/plot.rs b/examples/plot.rs index 2be99c9..967b3d0 100644 --- a/examples/plot.rs +++ b/examples/plot.rs @@ -66,7 +66,7 @@ struct PerformanceStats { low: f32, } -const OUTPUT_FILE_NAME: &'static str = "fitness-plot.png"; +const OUTPUT_FILE_NAME: &'static str = "fitness-plot.svg"; const GENS: usize = 100; fn main() -> Result<(), Box> { let mut rng = rand::thread_rng(); @@ -86,6 +86,9 @@ fn main() -> Result<(), Box> { sim.next_generation(); } + // prevent `Arc::into_inner` from failing + drop(sim); + println!("Training complete, collecting data and building chart..."); let root = SVGBackend::new(OUTPUT_FILE_NAME, (640, 480)).into_drawing_area(); From 945ea4a7b1a350d75b9260f2d903ddb1566fc848 Mon Sep 17 00:00:00 2001 From: Tristan Murphy <72839119+inflectrix@users.noreply.github.com> Date: Wed, 15 May 2024 13:43:33 +0000 Subject: [PATCH 15/27] fix data retrieval --- examples/plot.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/examples/plot.rs b/examples/plot.rs index 967b3d0..33c032a 100644 --- a/examples/plot.rs +++ b/examples/plot.rs @@ -44,7 +44,10 @@ struct PlottingNG> { } impl> NextgenFn for PlottingNG { - fn next_gen(&self, fitness: Vec<(AgentDNA, f32)>) -> Vec { + fn next_gen(&self, mut fitness: Vec<(AgentDNA, f32)>) -> Vec { + // it's a bit slower because of sorting twice but I don't want to rewrite the nextgen. + fitness.sort_by(|(_, fa), (_, fb)| fa.partial_cmp(fb).unwrap()); + let l = fitness.len(); let high = fitness[l-1].1; From 0717843bfd615421f20b352f165b8a758822bacc Mon Sep 17 00:00:00 2001 From: Tristan Murphy <72839119+inflectrix@users.noreply.github.com> Date: Wed, 15 May 2024 13:48:58 +0000 Subject: [PATCH 16/27] make compatible with other features --- examples/plot.rs | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/examples/plot.rs b/examples/plot.rs index 33c032a..ab0585a 100644 --- a/examples/plot.rs +++ b/examples/plot.rs @@ -71,14 +71,21 @@ struct PerformanceStats { const OUTPUT_FILE_NAME: &'static str = "fitness-plot.svg"; const GENS: usize = 100; + fn main() -> Result<(), Box> { + #[cfg(not(feature = "rayon"))] let mut rng = rand::thread_rng(); let performance_stats = Arc::new(Mutex::new(Vec::with_capacity(GENS))); let ng = PlottingNG { performance_stats: performance_stats.clone(), actual_ng: division_pruning_nextgen }; let mut sim = GeneticSim::new( + #[cfg(not(feature = "rayon"))] Vec::gen_random(&mut rng, 100), + + #[cfg(feature = "rayon")] + Vec::gen_random(100), + fitness, ng, ); From 6a98fb0d928f8922e6c535f3a74312a89a277d43 Mon Sep 17 00:00:00 2001 From: Tristan Murphy <72839119+inflectrix@users.noreply.github.com> Date: Wed, 15 May 2024 13:49:53 +0000 Subject: [PATCH 17/27] cargo fmt --- examples/plot.rs | 38 +++++++++++++++++++++++--------------- src/lib.rs | 9 +++------ 2 files changed, 26 insertions(+), 21 deletions(-) diff --git a/examples/plot.rs b/examples/plot.rs index ab0585a..2b6a851 100644 --- a/examples/plot.rs +++ b/examples/plot.rs @@ -1,11 +1,14 @@ -use std::{error::Error, sync::{Arc, Mutex}}; +use std::{ + error::Error, + sync::{Arc, Mutex}, +}; use neat::*; -use rand::prelude::*; use plotters::prelude::*; +use rand::prelude::*; #[derive(RandomlyMutable, DivisionReproduction, Clone)] -struct AgentDNA { +struct AgentDNA { network: NeuralNetworkTopology<2, 1>, } @@ -50,7 +53,7 @@ impl> NextgenFn for PlottingNG { let l = fitness.len(); - let high = fitness[l-1].1; + let high = fitness[l - 1].1; let median = fitness[l / 2].1; @@ -77,21 +80,22 @@ fn main() -> Result<(), Box> { let mut rng = rand::thread_rng(); let performance_stats = Arc::new(Mutex::new(Vec::with_capacity(GENS))); - let ng = PlottingNG { performance_stats: performance_stats.clone(), actual_ng: division_pruning_nextgen }; + let ng = PlottingNG { + performance_stats: performance_stats.clone(), + actual_ng: division_pruning_nextgen, + }; let mut sim = GeneticSim::new( #[cfg(not(feature = "rayon"))] Vec::gen_random(&mut rng, 100), - #[cfg(feature = "rayon")] Vec::gen_random(100), - fitness, ng, ); println!("Training..."); - + for _ in 0..GENS { sim.next_generation(); } @@ -105,7 +109,10 @@ fn main() -> Result<(), Box> { root.fill(&WHITE)?; let mut chart = ChartBuilder::on(&root) - .caption("agent fitness values per generation", ("sans-serif", 50).into_font()) + .caption( + "agent fitness values per generation", + ("sans-serif", 50).into_font(), + ) .margin(5) .x_label_area_size(30) .y_label_area_size(30) @@ -113,7 +120,10 @@ fn main() -> Result<(), Box> { chart.configure_mesh().draw()?; - let data: Vec<_> = Arc::into_inner(performance_stats).unwrap().into_inner().unwrap() + let data: Vec<_> = Arc::into_inner(performance_stats) + .unwrap() + .into_inner() + .unwrap() .into_iter() .enumerate() .collect(); @@ -138,9 +148,7 @@ fn main() -> Result<(), Box> { .draw_series(LineSeries::new(medians, &YELLOW))? .label("median"); - chart - .draw_series(LineSeries::new(lows, &RED))? - .label("low"); + chart.draw_series(LineSeries::new(lows, &RED))?.label("low"); chart .configure_series_labels() @@ -151,6 +159,6 @@ fn main() -> Result<(), Box> { root.present()?; println!("Complete"); - + Ok(()) -} \ No newline at end of file +} diff --git a/src/lib.rs b/src/lib.rs index ac569b0..98429d6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -30,7 +30,7 @@ mod tests { use rand::prelude::*; #[derive(RandomlyMutable, DivisionReproduction, Clone)] - struct AgentDNA { + struct AgentDNA { network: NeuralNetworkTopology<2, 1>, } @@ -77,13 +77,10 @@ mod tests { sim.next_generation(); } - let mut fits: Vec<_> = sim.genomes - .iter() - .map(fitness) - .collect(); + let mut fits: Vec<_> = sim.genomes.iter().map(fitness).collect(); fits.sort_by(|a, b| a.partial_cmp(&b).unwrap()); dbg!(fits); } -} \ No newline at end of file +} From 44b7fdbc37992f766d322f22c14e7133eab9a481 Mon Sep 17 00:00:00 2001 From: Tristan Murphy <72839119+inflectrix@users.noreply.github.com> Date: Wed, 15 May 2024 14:24:25 +0000 Subject: [PATCH 18/27] create progress bar for plotting example --- Cargo.lock | 69 ++++++++++++++++++++++++++++++++++++++++++++++++ Cargo.toml | 4 ++- examples/plot.rs | 14 ++++++++-- 3 files changed, 84 insertions(+), 3 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 5c98cd1..700cf2f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -100,6 +100,19 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3d7b894f5411737b7867f4827955924d7c254fc9f4d91a6aad6b097804b1018b" +[[package]] +name = "console" +version = "0.15.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e1f83fc076bd6dd27517eacdf25fef6c4dfe5f1d7448bafaaf3a26f13b5e4eb" +dependencies = [ + "encode_unicode", + "lazy_static", + "libc", + "unicode-width", + "windows-sys", +] + [[package]] name = "const-cstr" version = "0.3.0" @@ -240,6 +253,12 @@ version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" +[[package]] +name = "encode_unicode" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f" + [[package]] name = "fdeflate" version = "0.3.4" @@ -417,6 +436,28 @@ dependencies = [ "png", ] +[[package]] +name = "indicatif" +version = "0.17.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "763a5a8f45087d6bcea4222e7b72c291a054edf80e4ef6efd2a4979878c7bea3" +dependencies = [ + "console", + "instant", + "number_prefix", + "portable-atomic", + "unicode-width", +] + +[[package]] +name = "instant" +version = "0.1.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a5bbe824c507c5da5956355e86a746d82e0e1464f65d862cc5e71da70e94b2c" +dependencies = [ + "cfg-if", +] + [[package]] name = "itoa" version = "1.0.10" @@ -493,6 +534,7 @@ dependencies = [ "bincode", "bitflags 2.5.0", "genetic-rs", + "indicatif", "lazy_static", "plotters", "rand", @@ -511,6 +553,12 @@ dependencies = [ "autocfg", ] +[[package]] +name = "number_prefix" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" + [[package]] name = "once_cell" version = "1.19.0" @@ -601,6 +649,12 @@ dependencies = [ "miniz_oxide", ] +[[package]] +name = "portable-atomic" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7170ef9988bc169ba16dd36a7fa041e5c4cbeb6a35b76d4c03daded371eae7c0" + [[package]] name = "ppv-lite86" version = "0.2.17" @@ -811,6 +865,12 @@ version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" +[[package]] +name = "unicode-width" +version = "0.1.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68f5e5f3158ecfd4b8ff6fe086db7c8467a2dfdac97fe420f2b7c4aa97af66d6" + [[package]] name = "walkdir" version = "2.5.0" @@ -937,6 +997,15 @@ dependencies = [ "windows-targets", ] +[[package]] +name = "windows-sys" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +dependencies = [ + "windows-targets", +] + [[package]] name = "windows-targets" version = "0.52.5" diff --git a/Cargo.toml b/Cargo.toml index 8ccd8ca..8305fe4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,6 +28,7 @@ serde = ["dep:serde", "dep:serde-big-array"] [dependencies] bitflags = "2.5.0" genetic-rs = { version = "0.5.1", features = ["derive"] } + lazy_static = "1.4.0" rand = "0.8.5" rayon = { version = "1.8.1", optional = true } @@ -37,4 +38,5 @@ serde-big-array = { version = "0.5.1", optional = true } [dev-dependencies] bincode = "1.3.3" serde_json = "1.0.114" -plotters = "0.3.5" \ No newline at end of file +plotters = "0.3.5" +indicatif = "0.17.8" diff --git a/examples/plot.rs b/examples/plot.rs index 2b6a851..af48b01 100644 --- a/examples/plot.rs +++ b/examples/plot.rs @@ -6,6 +6,7 @@ use std::{ use neat::*; use plotters::prelude::*; use rand::prelude::*; +use indicatif::{ProgressBar, ProgressStyle}; #[derive(RandomlyMutable, DivisionReproduction, Clone)] struct AgentDNA { @@ -73,7 +74,7 @@ struct PerformanceStats { } const OUTPUT_FILE_NAME: &'static str = "fitness-plot.svg"; -const GENS: usize = 100; +const GENS: usize = 1000; fn main() -> Result<(), Box> { #[cfg(not(feature = "rayon"))] @@ -94,12 +95,21 @@ fn main() -> Result<(), Box> { ng, ); + let pb = ProgressBar::new(GENS as u64) + .with_style(ProgressStyle::with_template("[{elapsed_precise}] {bar:40.cyan/blue} | {msg} {pos}/{len}") + .unwrap()) + .with_message("gen"); + println!("Training..."); for _ in 0..GENS { sim.next_generation(); + + pb.inc(1); } + pb.finish(); + // prevent `Arc::into_inner` from failing drop(sim); @@ -116,7 +126,7 @@ fn main() -> Result<(), Box> { .margin(5) .x_label_area_size(30) .y_label_area_size(30) - .build_cartesian_2d(0usize..100, 0f32..200.0)?; + .build_cartesian_2d(0usize..GENS, 0f32..1000.0)?; chart.configure_mesh().draw()?; From 6d17ec6bf1682f549ec8533aaa01f97ca0956dfe Mon Sep 17 00:00:00 2001 From: Tristan Murphy <72839119+inflectrix@users.noreply.github.com> Date: Wed, 15 May 2024 14:27:12 +0000 Subject: [PATCH 19/27] create progress bar for basic example --- examples/basic.rs | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/examples/basic.rs b/examples/basic.rs index 9ad0419..2aa640b 100644 --- a/examples/basic.rs +++ b/examples/basic.rs @@ -2,6 +2,7 @@ use neat::*; use rand::prelude::*; +use indicatif::{ProgressBar, ProgressStyle}; #[derive(PartialEq, Clone, Debug, DivisionReproduction, RandomlyMutable)] #[cfg_attr(feature = "crossover", derive(CrossoverReproduction))] @@ -103,10 +104,19 @@ fn main() { crossover_pruning_nextgen, ); - for _ in 0..100 { + const GENS: u64 = 1000; + let pb = ProgressBar::new(GENS) + .with_style(ProgressStyle::with_template("[{elapsed_precise}] {bar:40.cyan/blue} | {msg} {pos}/{len}") + .unwrap()) + .with_message("gen"); + + for _ in 0..GENS { sim.next_generation(); + pb.inc(1); } + pb.finish(); + #[cfg(not(feature = "serde"))] let mut fits: Vec<_> = sim.genomes.iter().map(fitness).collect(); From d3a9c409f51c11e7069b414e85069a1d7d7d7a76 Mon Sep 17 00:00:00 2001 From: Tristan Murphy <72839119+inflectrix@users.noreply.github.com> Date: Wed, 15 May 2024 14:28:23 +0000 Subject: [PATCH 20/27] cargo fmt --- examples/basic.rs | 10 +++++++--- examples/plot.rs | 10 +++++++--- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/examples/basic.rs b/examples/basic.rs index 2aa640b..9bbb346 100644 --- a/examples/basic.rs +++ b/examples/basic.rs @@ -1,8 +1,8 @@ //! A basic example of NEAT with this crate. Enable the `crossover` feature for it to use crossover reproduction +use indicatif::{ProgressBar, ProgressStyle}; use neat::*; use rand::prelude::*; -use indicatif::{ProgressBar, ProgressStyle}; #[derive(PartialEq, Clone, Debug, DivisionReproduction, RandomlyMutable)] #[cfg_attr(feature = "crossover", derive(CrossoverReproduction))] @@ -106,8 +106,12 @@ fn main() { const GENS: u64 = 1000; let pb = ProgressBar::new(GENS) - .with_style(ProgressStyle::with_template("[{elapsed_precise}] {bar:40.cyan/blue} | {msg} {pos}/{len}") - .unwrap()) + .with_style( + ProgressStyle::with_template( + "[{elapsed_precise}] {bar:40.cyan/blue} | {msg} {pos}/{len}", + ) + .unwrap(), + ) .with_message("gen"); for _ in 0..GENS { diff --git a/examples/plot.rs b/examples/plot.rs index af48b01..34fb391 100644 --- a/examples/plot.rs +++ b/examples/plot.rs @@ -3,10 +3,10 @@ use std::{ sync::{Arc, Mutex}, }; +use indicatif::{ProgressBar, ProgressStyle}; use neat::*; use plotters::prelude::*; use rand::prelude::*; -use indicatif::{ProgressBar, ProgressStyle}; #[derive(RandomlyMutable, DivisionReproduction, Clone)] struct AgentDNA { @@ -96,8 +96,12 @@ fn main() -> Result<(), Box> { ); let pb = ProgressBar::new(GENS as u64) - .with_style(ProgressStyle::with_template("[{elapsed_precise}] {bar:40.cyan/blue} | {msg} {pos}/{len}") - .unwrap()) + .with_style( + ProgressStyle::with_template( + "[{elapsed_precise}] {bar:40.cyan/blue} | {msg} {pos}/{len}", + ) + .unwrap(), + ) .with_message("gen"); println!("Training..."); From e45908cacd6126997f9d26a4a1dffa99c9f25244 Mon Sep 17 00:00:00 2001 From: Tristan Murphy Date: Fri, 31 May 2024 21:40:06 -0400 Subject: [PATCH 21/27] add logic to prevent duplicated input neurons --- src/topology/mod.rs | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/topology/mod.rs b/src/topology/mod.rs index 02ad296..dd246f2 100644 --- a/src/topology/mod.rs +++ b/src/topology/mod.rs @@ -121,6 +121,18 @@ impl NeuralNetworkTopology { return true; } + // check to make sure it isn't duplicate + { + let n = self.get_neuron(to); + let n2 = n.read().unwrap(); + + for (loc, _) in &n2.inputs { + if from == *loc { + return false; + } + } + } + let mut visited = HashSet::new(); self.dfs(from, to, &mut visited) } From 7c31f30f88bcd0221e628daa8b0d0530f5c46523 Mon Sep 17 00:00:00 2001 From: Tristan Murphy Date: Thu, 13 Jun 2024 10:54:22 -0400 Subject: [PATCH 22/27] cargo fmt --- examples/custom_activation.rs | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/examples/custom_activation.rs b/examples/custom_activation.rs index f52882b..bc6aae2 100644 --- a/examples/custom_activation.rs +++ b/examples/custom_activation.rs @@ -27,11 +27,7 @@ fn fitness(g: &AgentDNA) -> f32 { let n = rng.gen::(); let n2 = rng.gen::(); - let expected = if (n + n2) / 2. >= 0.5 { - 0 - } else { - 1 - }; + let expected = if (n + n2) / 2. >= 0.5 { 0 } else { 1 }; let result = network.predict([n, n2]); network.flush_state(); @@ -73,15 +69,11 @@ fn main() { let mut sim = GeneticSim::new( #[cfg(not(feature = "rayon"))] Vec::gen_random(&mut rng, 100), - #[cfg(feature = "rayon")] Vec::gen_random(100), - fitness, - #[cfg(not(feature = "serde"))] division_pruning_nextgen, - #[cfg(feature = "serde")] serde_nextgen, ); @@ -89,4 +81,4 @@ fn main() { for _ in 0..200 { sim.next_generation(); } -} \ No newline at end of file +} From a32bfff0375835cc9f49ad2d4e42513a8cb55a5c Mon Sep 17 00:00:00 2001 From: Tristan Murphy Date: Thu, 13 Jun 2024 11:09:52 -0400 Subject: [PATCH 23/27] change activation function to one that doesn't return NaN --- examples/custom_activation.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/custom_activation.rs b/examples/custom_activation.rs index bc6aae2..7b37c02 100644 --- a/examples/custom_activation.rs +++ b/examples/custom_activation.rs @@ -60,8 +60,8 @@ fn serde_nextgen(rewards: Vec<(AgentDNA, f32)>) -> Vec { } fn main() { - let log_activation = activation_fn!(f32::log10); - register_activation(log_activation); + let sin_activation = activation_fn!(f32::sin); + register_activation(sin_activation); #[cfg(not(feature = "rayon"))] let mut rng = rand::thread_rng(); From 627830a61b25d4859644ca0ed55711f9dbde33a0 Mon Sep 17 00:00:00 2001 From: Tristan Murphy <72839119+HyperCodec@users.noreply.github.com> Date: Tue, 24 Sep 2024 12:08:01 -0400 Subject: [PATCH 24/27] Update CONTRIBUTING.md --- CONTRIBUTING.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 68c1a8c..683c357 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,6 +1,6 @@ Thanks for contributing to this project. -To get started, check out the [issues page](https://github.com/inflectrix/neat). You can either find a feature/fix from there or start a new issue, then begin implementing it in your own fork of this repo. +To get started, check out the [issues page](https://github.com/hypercodec/neat). You can either find a feature/fix from there or start a new issue, then begin implementing it in your own fork of this repo. -Once you are done making the changes you'd like the make, start a pull request to the [dev](https://github.com/inflectrix/neat/tree/dev) branch. State your changes and request a review. After all branch rules have been satisfied, someone with management permissions on this repository will merge it. +Once you are done making the changes you'd like the make, start a pull request to the [dev](https://github.com/hypercodec/neat/tree/dev) branch. State your changes and request a review. After all branch rules have been satisfied, someone with management permissions on this repository will merge it. From 1978058ed0a3b0adf25e3deb688e4e4cbfbf5733 Mon Sep 17 00:00:00 2001 From: Tristan Murphy <72839119+HyperCodec@users.noreply.github.com> Date: Tue, 4 Feb 2025 14:02:23 +0000 Subject: [PATCH 25/27] Merge branch 'rewrite' into dev --- .github/workflows/ci-cd.yml | 2 +- CONTRIBUTING.md | 3 +- Cargo.lock | 885 ++----------------------------- Cargo.toml | 24 +- README.md | 97 +--- examples/basic.rs | 143 +---- examples/extra_dna.rs | 3 + src/{topology => }/activation.rs | 101 ++-- src/activation/builtin.rs | 14 + src/lib.rs | 93 +--- src/neuralnet.rs | 856 ++++++++++++++++++++++++++++++ src/runnable.rs | 300 ----------- src/tests.rs | 179 +++++++ src/topology/mod.rs | 638 ---------------------- src/topology/nnt_serde.rs | 71 --- 15 files changed, 1165 insertions(+), 2244 deletions(-) create mode 100644 examples/extra_dna.rs rename src/{topology => }/activation.rs (79%) create mode 100644 src/activation/builtin.rs create mode 100644 src/neuralnet.rs delete mode 100644 src/runnable.rs create mode 100644 src/tests.rs delete mode 100644 src/topology/nnt_serde.rs diff --git a/.github/workflows/ci-cd.yml b/.github/workflows/ci-cd.yml index fef7ca6..fd849f4 100644 --- a/.github/workflows/ci-cd.yml +++ b/.github/workflows/ci-cd.yml @@ -2,7 +2,7 @@ name: CI-CD on: push: - branches: [main] + branches: [main, dev] pull_request: jobs: diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 683c357..8ede4a4 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -2,5 +2,6 @@ Thanks for contributing to this project. To get started, check out the [issues page](https://github.com/hypercodec/neat). You can either find a feature/fix from there or start a new issue, then begin implementing it in your own fork of this repo. -Once you are done making the changes you'd like the make, start a pull request to the [dev](https://github.com/hypercodec/neat/tree/dev) branch. State your changes and request a review. After all branch rules have been satisfied, someone with management permissions on this repository will merge it. +Once you are done making the changes you'd like the make, start a pull request to the [dev](https://github.com/hypercodec/neat/tree/dev) branch. State your changes and request a review. After all branch rules have been satisfied and the pull request has a valid reason, someone with management permissions on this repository will merge it. +You could also make a draft PR while implementing your features if you want feedback or discussion before finalizing your changes. \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index 700cf2f..0a9a53b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,78 +1,18 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -version = 3 +version = 4 [[package]] -name = "adler" -version = "1.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" - -[[package]] -name = "android-tzdata" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0" - -[[package]] -name = "android_system_properties" -version = "0.1.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" -dependencies = [ - "libc", -] - -[[package]] -name = "autocfg" -version = "1.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1fdabc7756949593fe60f30ec81974b613357de856987752631dea1e3394c80" - -[[package]] -name = "bincode" -version = "1.3.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad" -dependencies = [ - "serde", -] - -[[package]] -name = "bitflags" -version = "1.3.2" +name = "atomic_float" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" +checksum = "628d228f918ac3b82fe590352cc719d30664a0c13ca3a60266fe02c7132d480a" [[package]] name = "bitflags" -version = "2.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf4b9d6a944f767f8e5e0db018570623c85f3d925ac718db4e06d0187adb21c1" - -[[package]] -name = "bumpalo" -version = "3.16.0" +version = "2.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" - -[[package]] -name = "bytemuck" -version = "1.15.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d6d68c57235a3a081186990eca2867354726650f42f7516ca50c28d6281fd15" - -[[package]] -name = "byteorder" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" - -[[package]] -name = "cc" -version = "1.0.94" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "17f6e324229dc011159fcc089755d1e2e216a90d43a7dea6853ca740b84f35e7" +checksum = "8f68f53c83ab957f72c32642f3868eec03eb974d1fb82e453128456482613d36" [[package]] name = "cfg-if" @@ -80,106 +20,6 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" -[[package]] -name = "chrono" -version = "0.4.38" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a21f936df1771bf62b77f047b726c4625ff2e8aa607c01ec06e5a05bd8463401" -dependencies = [ - "android-tzdata", - "iana-time-zone", - "js-sys", - "num-traits", - "wasm-bindgen", - "windows-targets", -] - -[[package]] -name = "color_quant" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d7b894f5411737b7867f4827955924d7c254fc9f4d91a6aad6b097804b1018b" - -[[package]] -name = "console" -version = "0.15.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e1f83fc076bd6dd27517eacdf25fef6c4dfe5f1d7448bafaaf3a26f13b5e4eb" -dependencies = [ - "encode_unicode", - "lazy_static", - "libc", - "unicode-width", - "windows-sys", -] - -[[package]] -name = "const-cstr" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ed3d0b5ff30645a68f35ece8cea4556ca14ef8a1651455f789a099a0513532a6" - -[[package]] -name = "core-foundation" -version = "0.9.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f" -dependencies = [ - "core-foundation-sys", - "libc", -] - -[[package]] -name = "core-foundation-sys" -version = "0.8.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f" - -[[package]] -name = "core-graphics" -version = "0.22.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2581bbab3b8ffc6fcbd550bf46c355135d16e9ff2a6ea032ad6b9bf1d7efe4fb" -dependencies = [ - "bitflags 1.3.2", - "core-foundation", - "core-graphics-types", - "foreign-types", - "libc", -] - -[[package]] -name = "core-graphics-types" -version = "0.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "45390e6114f68f718cc7a830514a96f903cccd70d02a8f6d9f643ac4ba45afaf" -dependencies = [ - "bitflags 1.3.2", - "core-foundation", - "libc", -] - -[[package]] -name = "core-text" -version = "19.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "99d74ada66e07c1cefa18f8abfba765b486f250de2e4a999e5727fc0dd4b4a25" -dependencies = [ - "core-foundation", - "core-graphics", - "foreign-types", - "libc", -] - -[[package]] -name = "crc32fast" -version = "1.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3855a8a784b474f333699ef2bbca9db2c4a1f6d9088a90a2d25b1eb53111eaa" -dependencies = [ - "cfg-if", -] - [[package]] name = "crossbeam-deque" version = "0.8.5" @@ -205,151 +45,17 @@ version = "0.8.19" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "248e3bacc7dc6baa3b21e405ee045c3047101a49145e7e9eca583ab4c2ca5345" -[[package]] -name = "dirs-next" -version = "2.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b98cf8ebf19c3d1b223e151f99a4f9f0690dca41414773390fc824184ac833e1" -dependencies = [ - "cfg-if", - "dirs-sys-next", -] - -[[package]] -name = "dirs-sys-next" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ebda144c4fe02d1f7ea1a7d9641b6fc6b580adcfa024ae48797ecdeb6825b4d" -dependencies = [ - "libc", - "redox_users", - "winapi", -] - -[[package]] -name = "dlib" -version = "0.5.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "330c60081dcc4c72131f8eb70510f1ac07223e5d4163db481a04a0befcffa412" -dependencies = [ - "libloading", -] - -[[package]] -name = "dwrote" -version = "0.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "439a1c2ba5611ad3ed731280541d36d2e9c4ac5e7fb818a27b604bdc5a6aa65b" -dependencies = [ - "lazy_static", - "libc", - "winapi", - "wio", -] - [[package]] name = "either" version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" -[[package]] -name = "encode_unicode" -version = "0.3.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f" - -[[package]] -name = "fdeflate" -version = "0.3.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4f9bfee30e4dedf0ab8b422f03af778d9612b63f502710fc500a334ebe2de645" -dependencies = [ - "simd-adler32", -] - -[[package]] -name = "flate2" -version = "1.0.28" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "46303f565772937ffe1d394a4fac6f411c6013172fadde9dcdb1e147a086940e" -dependencies = [ - "crc32fast", - "miniz_oxide", -] - -[[package]] -name = "float-ord" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7bad48618fdb549078c333a7a8528acb57af271d0433bdecd523eb620628364e" - -[[package]] -name = "font-kit" -version = "0.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "21fe28504d371085fae9ac7a3450f0b289ab71e07c8e57baa3fb68b9e57d6ce5" -dependencies = [ - "bitflags 1.3.2", - "byteorder", - "core-foundation", - "core-graphics", - "core-text", - "dirs-next", - "dwrote", - "float-ord", - "freetype", - "lazy_static", - "libc", - "log", - "pathfinder_geometry", - "pathfinder_simd", - "walkdir", - "winapi", - "yeslogic-fontconfig-sys", -] - -[[package]] -name = "foreign-types" -version = "0.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" -dependencies = [ - "foreign-types-shared", -] - -[[package]] -name = "foreign-types-shared" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" - -[[package]] -name = "freetype" -version = "0.7.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "efc8599a3078adf8edeb86c71e9f8fa7d88af5ca31e806a867756081f90f5d83" -dependencies = [ - "freetype-sys", - "libc", -] - -[[package]] -name = "freetype-sys" -version = "0.19.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "66ee28c39a43d89fbed8b4798fb4ba56722cfd2b5af81f9326c27614ba88ecd5" -dependencies = [ - "cc", - "libc", - "pkg-config", -] - [[package]] name = "genetic-rs" -version = "0.5.1" +version = "0.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b94601f3db2fb341f71a4470134eb1f71d39f54c2fe264122698eda67cd1c91b" +checksum = "a68bb62a836f6ea3261d77cfec4012316e206f53e7d0eab519f5f3630e86001f" dependencies = [ "genetic-rs-common", "genetic-rs-macros", @@ -357,9 +63,9 @@ dependencies = [ [[package]] name = "genetic-rs-common" -version = "0.5.1" +version = "0.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4f41b0e3f6ccb66a00e7fc9170d4e02b1ae80c85f03c67b76b067b3637fd314a" +checksum = "3be7aaffd4e4dc82d11819d40794f089c37d02595a401f229ed2877d1a4c401d" dependencies = [ "rand", "rayon", @@ -368,9 +74,9 @@ dependencies = [ [[package]] name = "genetic-rs-macros" -version = "0.5.1" +version = "0.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2d5ec3b9e69a6836bb0f0c8fa6972e6322e0b49108f7b3ed40769feb452c120a" +checksum = "4e73b1f36ea3e799232e1a3141a2765fa6ee9ed7bb3fed96ccfb3bf272d1832e" dependencies = [ "genetic-rs-common", "proc-macro2", @@ -389,272 +95,45 @@ dependencies = [ "wasi", ] -[[package]] -name = "gif" -version = "0.12.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "80792593675e051cf94a4b111980da2ba60d4a83e43e0048c5693baab3977045" -dependencies = [ - "color_quant", - "weezl", -] - -[[package]] -name = "iana-time-zone" -version = "0.1.60" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7ffbb5a1b541ea2561f8c41c087286cc091e21e556a4f09a8f6cbf17b69b141" -dependencies = [ - "android_system_properties", - "core-foundation-sys", - "iana-time-zone-haiku", - "js-sys", - "wasm-bindgen", - "windows-core", -] - -[[package]] -name = "iana-time-zone-haiku" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" -dependencies = [ - "cc", -] - -[[package]] -name = "image" -version = "0.24.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5690139d2f55868e080017335e4b94cb7414274c74f1669c84fb5feba2c9f69d" -dependencies = [ - "bytemuck", - "byteorder", - "color_quant", - "jpeg-decoder", - "num-traits", - "png", -] - -[[package]] -name = "indicatif" -version = "0.17.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "763a5a8f45087d6bcea4222e7b72c291a054edf80e4ef6efd2a4979878c7bea3" -dependencies = [ - "console", - "instant", - "number_prefix", - "portable-atomic", - "unicode-width", -] - -[[package]] -name = "instant" -version = "0.1.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a5bbe824c507c5da5956355e86a746d82e0e1464f65d862cc5e71da70e94b2c" -dependencies = [ - "cfg-if", -] - [[package]] name = "itoa" -version = "1.0.10" +version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c" - -[[package]] -name = "jpeg-decoder" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f5d4a7da358eff58addd2877a45865158f0d78c911d43a5784ceb7bbf52833b0" - -[[package]] -name = "js-sys" -version = "0.3.69" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "29c15563dc2726973df627357ce0c9ddddbea194836909d655df6a75d2cf296d" -dependencies = [ - "wasm-bindgen", -] +checksum = "d75a2a4b1b190afb6f5425f10f6a8f959d2ea0b9c2b1d79553551850539e4674" [[package]] name = "lazy_static" -version = "1.4.0" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" +checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" [[package]] name = "libc" -version = "0.2.153" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd" - -[[package]] -name = "libloading" -version = "0.8.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c2a198fb6b0eada2a8df47933734e6d35d350665a33a3593d7164fa52c75c19" -dependencies = [ - "cfg-if", - "windows-targets", -] - -[[package]] -name = "libredox" -version = "0.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c0ff37bd590ca25063e35af745c343cb7a0271906fb7b37e4813e8f79f00268d" -dependencies = [ - "bitflags 2.5.0", - "libc", -] - -[[package]] -name = "log" -version = "0.4.21" +version = "0.2.169" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c" +checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a" [[package]] -name = "miniz_oxide" -version = "0.7.2" +name = "memchr" +version = "2.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d811f3e15f28568be3407c8e7fdb6514c1cda3cb30683f15b6a1a1dc4ea14a7" -dependencies = [ - "adler", - "simd-adler32", -] +checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" [[package]] name = "neat" version = "0.5.1" dependencies = [ - "bincode", - "bitflags 2.5.0", + "atomic_float", + "bitflags", "genetic-rs", - "indicatif", "lazy_static", - "plotters", - "rand", "rayon", + "replace_with", "serde", "serde-big-array", "serde_json", ] -[[package]] -name = "num-traits" -version = "0.2.18" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da0df0e5185db44f69b44f26786fe401b6c293d1907744beaa7fa62b2e5a517a" -dependencies = [ - "autocfg", -] - -[[package]] -name = "number_prefix" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" - -[[package]] -name = "once_cell" -version = "1.19.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" - -[[package]] -name = "pathfinder_geometry" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b7b7e7b4ea703700ce73ebf128e1450eb69c3a8329199ffbfb9b2a0418e5ad3" -dependencies = [ - "log", - "pathfinder_simd", -] - -[[package]] -name = "pathfinder_simd" -version = "0.5.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ebf45976c56919841273f2a0fc684c28437e2f304e264557d9c72be5d5a718be" -dependencies = [ - "rustc_version", -] - -[[package]] -name = "pkg-config" -version = "0.3.30" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d231b230927b5e4ad203db57bbcbee2802f6bce620b1e4a9024a07d94e2907ec" - -[[package]] -name = "plotters" -version = "0.3.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2c224ba00d7cadd4d5c660deaf2098e5e80e07846537c51f9cfa4be50c1fd45" -dependencies = [ - "chrono", - "font-kit", - "image", - "lazy_static", - "num-traits", - "pathfinder_geometry", - "plotters-backend", - "plotters-bitmap", - "plotters-svg", - "ttf-parser", - "wasm-bindgen", - "web-sys", -] - -[[package]] -name = "plotters-backend" -version = "0.3.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e76628b4d3a7581389a35d5b6e2139607ad7c75b17aed325f210aa91f4a9609" - -[[package]] -name = "plotters-bitmap" -version = "0.3.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0cebbe1f70205299abc69e8b295035bb52a6a70ee35474ad10011f0a4efb8543" -dependencies = [ - "gif", - "image", - "plotters-backend", -] - -[[package]] -name = "plotters-svg" -version = "0.3.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38f6d39893cca0701371e3c27294f09797214b86f1fb951b89ade8ec04e2abab" -dependencies = [ - "plotters-backend", -] - -[[package]] -name = "png" -version = "0.17.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "06e4b0d3d1312775e782c86c91a111aa1f910cbb65e1337f9975b5f9a554b5e1" -dependencies = [ - "bitflags 1.3.2", - "crc32fast", - "fdeflate", - "flate2", - "miniz_oxide", -] - -[[package]] -name = "portable-atomic" -version = "1.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7170ef9988bc169ba16dd36a7fa041e5c4cbeb6a35b76d4c03daded371eae7c0" - [[package]] name = "ppv-lite86" version = "0.2.17" @@ -663,9 +142,9 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" [[package]] name = "proc-macro2" -version = "1.0.78" +version = "1.0.91" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2422ad645d89c99f8f3e6b88a9fdeca7fabeac836b1002371c4367c8f984aae" +checksum = "307e3004becf10f5a6e0d59d20f3cd28231b0e0827a96cd3e0ce6d14bc1e4bb3" dependencies = [ "unicode-ident", ] @@ -711,9 +190,9 @@ dependencies = [ [[package]] name = "rayon" -version = "1.8.1" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa7237101a77a10773db45d62004a272517633fbcc3df19d96455ede1122e051" +checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa" dependencies = [ "either", "rayon-core", @@ -729,58 +208,23 @@ dependencies = [ "crossbeam-utils", ] -[[package]] -name = "redox_users" -version = "0.4.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd283d9651eeda4b2a83a43c1c91b266c40fd76ecd39a50a8c630ae69dc72891" -dependencies = [ - "getrandom", - "libredox", - "thiserror", -] - [[package]] name = "replace_with" version = "0.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3a8614ee435691de62bcffcf4a66d91b3594bf1428a5722e79103249a095690" -[[package]] -name = "rustc_version" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366" -dependencies = [ - "semver", -] - [[package]] name = "ryu" -version = "1.0.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e86697c916019a8588c99b5fac3cead74ec0b4b819707a682fd4d23fa0ce1ba1" - -[[package]] -name = "same-file" -version = "1.0.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" -dependencies = [ - "winapi-util", -] - -[[package]] -name = "semver" -version = "1.0.22" +version = "1.0.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "92d43fe69e652f3df9bdc2b85b2854a0825b86e4fb76bc44d945137d053639ca" +checksum = "6ea1a2d0a644769cc99faa24c3ad26b379b786fe7c36fd3c546254801650e6dd" [[package]] name = "serde" -version = "1.0.197" +version = "1.0.217" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3fb1c873e1b9b056a4dc4c0c198b24c3ffa059243875552b2bd0933b1aee4ce2" +checksum = "02fc4265df13d6fa1d00ecff087228cc0a2b5f3c0e87e258d8b94a156e984c70" dependencies = [ "serde_derive", ] @@ -796,9 +240,9 @@ dependencies = [ [[package]] name = "serde_derive" -version = "1.0.197" +version = "1.0.217" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7eb0b34b42edc17f6b7cac84a52a1c5f0e1bb2227e997ca9011ea3dd34e8610b" +checksum = "5a9bf7cf98d04a2b28aead066b7496853d4779c9cc183c440dbac457641e19a0" dependencies = [ "proc-macro2", "quote", @@ -807,286 +251,35 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.114" +version = "1.0.138" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c5f09b1bd632ef549eaa9f60a1f8de742bdbc698e6cee2095fc84dde5f549ae0" +checksum = "d434192e7da787e94a6ea7e9670b26a036d0ca41e0b7efb2676dd32bae872949" dependencies = [ "itoa", + "memchr", "ryu", "serde", ] -[[package]] -name = "simd-adler32" -version = "0.3.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d66dc143e6b11c1eddc06d5c423cfc97062865baf299914ab64caa38182078fe" - [[package]] name = "syn" -version = "2.0.51" +version = "2.0.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ab617d94515e94ae53b8406c628598680aa0c9587474ecbe58188f7b345d66c" +checksum = "44d46482f1c1c87acd84dea20c1bf5ebff4c757009ed6bf19cfd36fb10e92c4e" dependencies = [ "proc-macro2", "quote", "unicode-ident", ] -[[package]] -name = "thiserror" -version = "1.0.58" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "03468839009160513471e86a034bb2c5c0e4baae3b43f79ffc55c4a5427b3297" -dependencies = [ - "thiserror-impl", -] - -[[package]] -name = "thiserror-impl" -version = "1.0.58" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c61f3ba182994efc43764a46c018c347bc492c79f024e705f46567b418f6d4f7" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "ttf-parser" -version = "0.17.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "375812fa44dab6df41c195cd2f7fecb488f6c09fbaafb62807488cefab642bff" - [[package]] name = "unicode-ident" version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" -[[package]] -name = "unicode-width" -version = "0.1.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "68f5e5f3158ecfd4b8ff6fe086db7c8467a2dfdac97fe420f2b7c4aa97af66d6" - -[[package]] -name = "walkdir" -version = "2.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b" -dependencies = [ - "same-file", - "winapi-util", -] - [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" - -[[package]] -name = "wasm-bindgen" -version = "0.2.92" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4be2531df63900aeb2bca0daaaddec08491ee64ceecbee5076636a3b026795a8" -dependencies = [ - "cfg-if", - "wasm-bindgen-macro", -] - -[[package]] -name = "wasm-bindgen-backend" -version = "0.2.92" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "614d787b966d3989fa7bb98a654e369c762374fd3213d212cfc0251257e747da" -dependencies = [ - "bumpalo", - "log", - "once_cell", - "proc-macro2", - "quote", - "syn", - "wasm-bindgen-shared", -] - -[[package]] -name = "wasm-bindgen-macro" -version = "0.2.92" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1f8823de937b71b9460c0c34e25f3da88250760bec0ebac694b49997550d726" -dependencies = [ - "quote", - "wasm-bindgen-macro-support", -] - -[[package]] -name = "wasm-bindgen-macro-support" -version = "0.2.92" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" -dependencies = [ - "proc-macro2", - "quote", - "syn", - "wasm-bindgen-backend", - "wasm-bindgen-shared", -] - -[[package]] -name = "wasm-bindgen-shared" -version = "0.2.92" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af190c94f2773fdb3729c55b007a722abb5384da03bc0986df4c289bf5567e96" - -[[package]] -name = "web-sys" -version = "0.3.69" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77afa9a11836342370f4817622a2f0f418b134426d91a82dfb48f532d2ec13ef" -dependencies = [ - "js-sys", - "wasm-bindgen", -] - -[[package]] -name = "weezl" -version = "0.1.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "53a85b86a771b1c87058196170769dd264f66c0782acf1ae6cc51bfd64b39082" - -[[package]] -name = "winapi" -version = "0.3.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" -dependencies = [ - "winapi-i686-pc-windows-gnu", - "winapi-x86_64-pc-windows-gnu", -] - -[[package]] -name = "winapi-i686-pc-windows-gnu" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" - -[[package]] -name = "winapi-util" -version = "0.1.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f29e6f9198ba0d26b4c9f07dbe6f9ed633e1f3d5b8b414090084349e46a52596" -dependencies = [ - "winapi", -] - -[[package]] -name = "winapi-x86_64-pc-windows-gnu" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" - -[[package]] -name = "windows-core" -version = "0.52.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" -dependencies = [ - "windows-targets", -] - -[[package]] -name = "windows-sys" -version = "0.52.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" -dependencies = [ - "windows-targets", -] - -[[package]] -name = "windows-targets" -version = "0.52.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f0713a46559409d202e70e28227288446bf7841d3211583a4b53e3f6d96e7eb" -dependencies = [ - "windows_aarch64_gnullvm", - "windows_aarch64_msvc", - "windows_i686_gnu", - "windows_i686_gnullvm", - "windows_i686_msvc", - "windows_x86_64_gnu", - "windows_x86_64_gnullvm", - "windows_x86_64_msvc", -] - -[[package]] -name = "windows_aarch64_gnullvm" -version = "0.52.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7088eed71e8b8dda258ecc8bac5fb1153c5cffaf2578fc8ff5d61e23578d3263" - -[[package]] -name = "windows_aarch64_msvc" -version = "0.52.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9985fd1504e250c615ca5f281c3f7a6da76213ebd5ccc9561496568a2752afb6" - -[[package]] -name = "windows_i686_gnu" -version = "0.52.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "88ba073cf16d5372720ec942a8ccbf61626074c6d4dd2e745299726ce8b89670" - -[[package]] -name = "windows_i686_gnullvm" -version = "0.52.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87f4261229030a858f36b459e748ae97545d6f1ec60e5e0d6a3d32e0dc232ee9" - -[[package]] -name = "windows_i686_msvc" -version = "0.52.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db3c2bf3d13d5b658be73463284eaf12830ac9a26a90c717b7f771dfe97487bf" - -[[package]] -name = "windows_x86_64_gnu" -version = "0.52.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e4246f76bdeff09eb48875a0fd3e2af6aada79d409d33011886d3e1581517d9" - -[[package]] -name = "windows_x86_64_gnullvm" -version = "0.52.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "852298e482cd67c356ddd9570386e2862b5673c85bd5f88df9ab6802b334c596" - -[[package]] -name = "windows_x86_64_msvc" -version = "0.52.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0" - -[[package]] -name = "wio" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d129932f4644ac2396cb456385cbf9e63b5b30c6e8dc4820bdca4eb082037a5" -dependencies = [ - "winapi", -] - -[[package]] -name = "yeslogic-fontconfig-sys" -version = "3.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2bbd69036d397ebbff671b1b8e4d918610c181c5a16073b96f984a38d08c386" -dependencies = [ - "const-cstr", - "dlib", - "once_cell", - "pkg-config", -] diff --git a/Cargo.toml b/Cargo.toml index 8305fe4..4b26e0f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,25 +18,19 @@ rustdoc-args = ["--cfg", "docsrs"] # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [features] -default = ["max-index"] -crossover = ["genetic-rs/crossover"] -rayon = ["genetic-rs/rayon", "dep:rayon"] -max-index = [] +default = [] serde = ["dep:serde", "dep:serde-big-array"] [dependencies] -bitflags = "2.5.0" -genetic-rs = { version = "0.5.1", features = ["derive"] } - -lazy_static = "1.4.0" -rand = "0.8.5" -rayon = { version = "1.8.1", optional = true } -serde = { version = "1.0.197", features = ["derive"], optional = true } +atomic_float = "1.1.0" +bitflags = "2.8.0" +genetic-rs = { version = "0.5.4", features = ["rayon", "derive"] } +lazy_static = "1.5.0" +rayon = "1.10.0" +replace_with = "0.1.7" +serde = { version = "1.0.217", features = ["derive"], optional = true } serde-big-array = { version = "0.5.1", optional = true } [dev-dependencies] -bincode = "1.3.3" -serde_json = "1.0.114" -plotters = "0.3.5" -indicatif = "0.17.8" +serde_json = "1.0.138" \ No newline at end of file diff --git a/README.md b/README.md index 1a97864..4e9828b 100644 --- a/README.md +++ b/README.md @@ -1,104 +1,17 @@ # neat -[github](https://github.com/inflectrix/neat) +[github](https://github.com/hypercodec/neat) [crates.io](https://crates.io/crates/neat) [docs.rs](https://docs.rs/neat) Implementation of the NEAT algorithm using `genetic-rs`. ### Features -- rayon - Uses parallelization on the `NeuralNetwork` struct and adds the `rayon` feature to the `genetic-rs` re-export. -- serde - Adds the NNTSerde struct and allows for serialization of `NeuralNetworkTopology` -- crossover - Implements the `CrossoverReproduction` trait on `NeuralNetworkTopology` and adds the `crossover` feature to the `genetic-rs` re-export. +- serde - Implements `Serialize` and `Deserialize` on most of the types in this crate. -*Do you like this repo and want to support it? If so, leave a ⭐* +*Do you like this crate and want to support it? If so, leave a ⭐* -### How To Use -When working with this crate, you'll want to use the `NeuralNetworkTopology` struct in your agent's DNA and -the use `NeuralNetwork::from` when you finally want to test its performance. The `genetic-rs` crate is also re-exported with the rest of this crate. - -Here's an example of how one might use this crate: -```rust -use neat::*; - -#[derive(Clone, RandomlyMutable, DivisionReproduction)] -struct MyAgentDNA { - network: NeuralNetworkTopology<1, 2>, -} - -impl GenerateRandom for MyAgentDNA { - fn gen_random(rng: &mut impl rand::Rng) -> Self { - Self { - network: NeuralNetworkTopology::new(0.01, 3, rng), - } - } -} - -struct MyAgent { - network: NeuralNetwork<1, 2>, - // ... other state -} - -impl From<&MyAgentDNA> for MyAgent { - fn from(value: &MyAgentDNA) -> Self { - Self { - network: NeuralNetwork::from(&value.network), - } - } -} - -fn fitness(dna: &MyAgentDNA) -> f32 { - // agent will simply try to predict whether a number is greater than 0.5 - let mut agent = MyAgent::from(dna); - let mut rng = rand::thread_rng(); - let mut fitness = 0; - - // use repeated tests to avoid situational bias and some local maximums, overall providing more accurate score - for _ in 0..10 { - let n = rng.gen::(); - let above = n > 0.5; - - let res = agent.network.predict([n]); - agent.network.flush_state(); - - let resi = res.iter().max_index(); - - if resi == 0 ^ above { - // agent did not guess correctly, punish slightly (too much will hinder exploration) - fitness -= 0.5; - - continue; - } - - // agent guessed correctly, they become more fit. - fitness += 3.; - } - - fitness -} - -fn main() { - let mut rng = rand::thread_rng(); - - let mut sim = GeneticSim::new( - Vec::gen_random(&mut rng, 100), - fitness, - division_pruning_nextgen, - ); - - // simulate 100 generations - for _ in 0..100 { - sim.next_generation(); - } - - // display fitness results - let fits: Vec<_> = sim.entities - .iter() - .map(fitness) - .collect(); - - dbg!(&fits, fits.iter().max()); -} -``` +# How To Use +TODO ### License This crate falls under the `MIT` license diff --git a/examples/basic.rs b/examples/basic.rs index 9bbb346..85f58cb 100644 --- a/examples/basic.rs +++ b/examples/basic.rs @@ -1,144 +1,3 @@ -//! A basic example of NEAT with this crate. Enable the `crossover` feature for it to use crossover reproduction - -use indicatif::{ProgressBar, ProgressStyle}; -use neat::*; -use rand::prelude::*; - -#[derive(PartialEq, Clone, Debug, DivisionReproduction, RandomlyMutable)] -#[cfg_attr(feature = "crossover", derive(CrossoverReproduction))] -struct AgentDNA { - network: NeuralNetworkTopology<2, 4>, -} - -impl Prunable for AgentDNA {} - -impl GenerateRandom for AgentDNA { - fn gen_random(rng: &mut impl rand::Rng) -> Self { - Self { - network: NeuralNetworkTopology::new(0.01, 3, rng), - } - } -} - -#[derive(Debug)] -struct Agent { - network: NeuralNetwork<2, 4>, -} - -impl From<&AgentDNA> for Agent { - fn from(value: &AgentDNA) -> Self { - Self { - network: (&value.network).into(), - } - } -} - -fn fitness(dna: &AgentDNA) -> f32 { - let agent = Agent::from(dna); - - let mut fitness = 0.; - let mut rng = rand::thread_rng(); - - for _ in 0..10 { - // 10 games - - // set up game - let mut agent_pos: (i32, i32) = (rng.gen_range(0..10), rng.gen_range(0..10)); - let mut food_pos: (i32, i32) = (rng.gen_range(0..10), rng.gen_range(0..10)); - - while food_pos == agent_pos { - food_pos = (rng.gen_range(0..10), rng.gen_range(0..10)); - } - - let mut step = 0; - - loop { - // perform actions in game - let action = agent.network.predict([ - (food_pos.0 - agent_pos.0) as f32, - (food_pos.1 - agent_pos.1) as f32, - ]); - let action = action.iter().max_index(); - - match action { - 0 => agent_pos.0 += 1, - 1 => agent_pos.0 -= 1, - 2 => agent_pos.1 += 1, - _ => agent_pos.1 -= 1, - } - - step += 1; - - if agent_pos == food_pos { - fitness += 10.; - break; // new game - } else { - // lose fitness for being slow and far away - fitness -= - (food_pos.0 - agent_pos.0 + food_pos.1 - agent_pos.1).abs() as f32 * 0.001; - } - - // 50 steps per game - if step == 50 { - break; - } - } - } - - fitness -} - fn main() { - #[cfg(not(feature = "rayon"))] - let mut rng = rand::thread_rng(); - - let mut sim = GeneticSim::new( - #[cfg(not(feature = "rayon"))] - Vec::gen_random(&mut rng, 100), - #[cfg(feature = "rayon")] - Vec::gen_random(100), - fitness, - #[cfg(not(feature = "crossover"))] - division_pruning_nextgen, - #[cfg(feature = "crossover")] - crossover_pruning_nextgen, - ); - - const GENS: u64 = 1000; - let pb = ProgressBar::new(GENS) - .with_style( - ProgressStyle::with_template( - "[{elapsed_precise}] {bar:40.cyan/blue} | {msg} {pos}/{len}", - ) - .unwrap(), - ) - .with_message("gen"); - - for _ in 0..GENS { - sim.next_generation(); - pb.inc(1); - } - - pb.finish(); - - #[cfg(not(feature = "serde"))] - let mut fits: Vec<_> = sim.genomes.iter().map(fitness).collect(); - - #[cfg(feature = "serde")] - let mut fits: Vec<_> = sim.genomes.iter().map(|e| (e, fitness(e))).collect(); - - #[cfg(not(feature = "serde"))] - fits.sort_by(|a, b| a.partial_cmp(&b).unwrap()); - - #[cfg(feature = "serde")] - fits.sort_by(|(_, a), (_, b)| a.partial_cmp(&b).unwrap()); - - dbg!(&fits); - - #[cfg(feature = "serde")] - { - let intermediate = NNTSerde::from(&fits[0].0.network); - let serialized = serde_json::to_string(&intermediate).unwrap(); - println!("{}", serialized); - } + todo!("use NeuralNetwork as the entire DNA"); } diff --git a/examples/extra_dna.rs b/examples/extra_dna.rs new file mode 100644 index 0000000..038709f --- /dev/null +++ b/examples/extra_dna.rs @@ -0,0 +1,3 @@ +fn main() { + todo!("use AgentDNA with additional params") +} diff --git a/src/topology/activation.rs b/src/activation.rs similarity index 79% rename from src/topology/activation.rs rename to src/activation.rs index 5bf9540..af9f74e 100644 --- a/src/topology/activation.rs +++ b/src/activation.rs @@ -1,7 +1,12 @@ +/// Contains some builtin activation functions ([`sigmoid`], [`relu`], etc.) +pub mod builtin; + +use bitflags::bitflags; +use builtin::*; + #[cfg(feature = "serde")] use serde::{Deserialize, Deserializer, Serialize, Serializer}; -use bitflags::bitflags; use lazy_static::lazy_static; use std::{ collections::HashMap, @@ -15,7 +20,7 @@ use crate::NeuronLocation; #[macro_export] macro_rules! activation_fn { ($F: path) => { - ActivationFn::new(std::sync::Arc::new($F), ActivationScope::default(), stringify!($F).into()) + ActivationFn::new(std::sync::Arc::new($F), NeuronScope::default(), stringify!($F).into()) }; ($F: path, $S: expr) => { @@ -73,11 +78,11 @@ impl ActivationRegistry { } /// Gets all activation functions that are valid for a scope. - pub fn activations_in_scope(&self, scope: ActivationScope) -> Vec { + pub fn activations_in_scope(&self, scope: NeuronScope) -> Vec { let acts = self.activations(); acts.into_iter() - .filter(|a| a.scope != ActivationScope::NONE && a.scope.contains(scope)) + .filter(|a| !a.scope.contains(NeuronScope::NONE) && a.scope.contains(scope)) .collect() } } @@ -88,51 +93,18 @@ impl Default for ActivationRegistry { fns: HashMap::new(), }; + // TODO add a way to disable this s.batch_register(activation_fn! { - sigmoid => ActivationScope::HIDDEN | ActivationScope::OUTPUT, - relu => ActivationScope::HIDDEN | ActivationScope::OUTPUT, - linear_activation => ActivationScope::INPUT | ActivationScope::HIDDEN | ActivationScope::OUTPUT, - f32::tanh => ActivationScope::HIDDEN | ActivationScope::OUTPUT + sigmoid => NeuronScope::HIDDEN | NeuronScope::OUTPUT, + relu => NeuronScope::HIDDEN | NeuronScope::OUTPUT, + linear_activation => NeuronScope::INPUT | NeuronScope::HIDDEN | NeuronScope::OUTPUT, + f32::tanh => NeuronScope::HIDDEN | NeuronScope::OUTPUT }); s } } -bitflags! { - /// Specifies where an activation function can occur - #[derive(Copy, Clone, Debug, Eq, PartialEq)] - pub struct ActivationScope: u8 { - /// Whether the activation can be applied to the input layer. - const INPUT = 0b001; - - /// Whether the activation can be applied to the hidden layer. - const HIDDEN = 0b010; - - /// Whether the activation can be applied to the output layer. - const OUTPUT = 0b100; - - /// The activation function will not be randomly placed anywhere - const NONE = 0b000; - } -} - -impl Default for ActivationScope { - fn default() -> Self { - Self::HIDDEN - } -} - -impl From<&NeuronLocation> for ActivationScope { - fn from(value: &NeuronLocation) -> Self { - match value { - NeuronLocation::Input(_) => Self::INPUT, - NeuronLocation::Hidden(_) => Self::HIDDEN, - NeuronLocation::Output(_) => Self::OUTPUT, - } - } -} - /// A trait that represents an activation method. pub trait Activation { /// The activation function. @@ -152,17 +124,13 @@ pub struct ActivationFn { pub func: Arc, /// The scope defining where the activation function can appear. - pub scope: ActivationScope, + pub scope: NeuronScope, pub(crate) name: String, } impl ActivationFn { /// Creates a new ActivationFn object. - pub fn new( - func: Arc, - scope: ActivationScope, - name: String, - ) -> Self { + pub fn new(func: Arc, scope: NeuronScope, name: String) -> Self { Self { func, name, scope } } } @@ -206,17 +174,36 @@ impl<'a> Deserialize<'a> for ActivationFn { } } -/// The sigmoid activation function. -pub fn sigmoid(n: f32) -> f32 { - 1. / (1. + std::f32::consts::E.powf(-n)) +bitflags! { + /// Specifies where an activation function can occur + #[derive(Copy, Clone, Debug, Eq, PartialEq)] + pub struct NeuronScope: u8 { + /// Whether the activation can be applied to the input layer. + const INPUT = 0b001; + + /// Whether the activation can be applied to the hidden layer. + const HIDDEN = 0b010; + + /// Whether the activation can be applied to the output layer. + const OUTPUT = 0b100; + + /// The activation function will not be randomly placed anywhere + const NONE = 0b000; + } } -/// The ReLU activation function. -pub fn relu(n: f32) -> f32 { - n.max(0.) +impl Default for NeuronScope { + fn default() -> Self { + Self::HIDDEN + } } -/// Activation function that does nothing. -pub fn linear_activation(n: f32) -> f32 { - n +impl> From for NeuronScope { + fn from(value: L) -> Self { + match value.as_ref() { + NeuronLocation::Input(_) => Self::INPUT, + NeuronLocation::Hidden(_) => Self::HIDDEN, + NeuronLocation::Output(_) => Self::OUTPUT, + } + } } diff --git a/src/activation/builtin.rs b/src/activation/builtin.rs new file mode 100644 index 0000000..fdf7ab7 --- /dev/null +++ b/src/activation/builtin.rs @@ -0,0 +1,14 @@ +/// The sigmoid activation function. Scales all values nonlinearly in the range of 1 to -1. +pub fn sigmoid(n: f32) -> f32 { + 1. / (1. + std::f32::consts::E.powf(-n)) +} + +/// The ReLU activation function. Equal to `n.max(0)`` +pub fn relu(n: f32) -> f32 { + n.max(0.) +} + +/// Activation function that does nothing. +pub fn linear_activation(n: f32) -> f32 { + n +} diff --git a/src/lib.rs b/src/lib.rs index 0dd0b8c..0de7360 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,90 +1,21 @@ -//! A simple crate that implements the Neuroevolution Augmenting Topologies algorithm using [genetic-rs](https://crates.io/crates/genetic-rs) -//! ### Feature Roadmap: -//! - [x] base (single-core) crate -//! - [x] rayon -//! - [x] serde -//! - [x] crossover +//! A crate implementing NeuroEvolution of Augmenting Topologies (NEAT). //! -//! You can get started by looking at [genetic-rs docs](https://docs.rs/genetic-rs) and checking the examples for this crate. +//! The goal is to provide a simple-to-use, very dynamic [`NeuralNetwork`] type that +//! integrates directly into the [`genetic-rs`](https://crates.io/crates/genetic-rs) ecosystem. +//! +//! Look at the README, docs, or examples to learn how to use this crate. #![warn(missing_docs)] -#![cfg_attr(docsrs, feature(doc_cfg))] -/// A module containing the [`NeuralNetworkTopology`] struct. This is what you want to use in the DNA of your agent, as it is the thing that goes through nextgens and suppors mutation. -pub mod topology; +/// Contains the types surrounding activation functions. +pub mod activation; -/// A module containing the main [`NeuralNetwork`] struct. -/// This has state/cache and will run the predictions. Make sure to run [`NeuralNetwork::flush_state`] between uses of [`NeuralNetwork::predict`]. -pub mod runnable; +/// Contains the [`NeuralNetwork`] and related types. +pub mod neuralnet; -pub use genetic_rs::prelude::*; -pub use runnable::*; -pub use topology::*; +pub use neuralnet::*; -#[cfg(feature = "serde")] -pub use nnt_serde::*; +pub use genetic_rs::{self, prelude::*}; #[cfg(test)] -mod tests { - use super::*; - use rand::prelude::*; - - #[derive(RandomlyMutable, DivisionReproduction, Clone)] - struct AgentDNA { - network: NeuralNetworkTopology<2, 1>, - } - - impl Prunable for AgentDNA {} - - impl GenerateRandom for AgentDNA { - fn gen_random(rng: &mut impl Rng) -> Self { - Self { - network: NeuralNetworkTopology::new(0.01, 3, rng), - } - } - } - - #[test] - fn basic_test() { - let fitness = |g: &AgentDNA| { - let network = NeuralNetwork::from(&g.network); - let mut fitness = 0.; - let mut rng = rand::thread_rng(); - - for _ in 0..100 { - let n = rng.gen::() * 10000.; - let base = rng.gen::() * 10.; - let expected = n.log(base); - - let [answer] = network.predict([n, base]); - network.flush_state(); - - fitness += 5. / (answer - expected).abs(); - } - - fitness - }; - - #[cfg(not(feature = "rayon"))] - let mut rng = rand::thread_rng(); - - let mut sim = GeneticSim::new( - #[cfg(not(feature = "rayon"))] - Vec::gen_random(&mut rng, 100), - #[cfg(feature = "rayon")] - Vec::gen_random(100), - fitness, - division_pruning_nextgen, - ); - - for _ in 0..100 { - sim.next_generation(); - } - - let mut fits: Vec<_> = sim.genomes.iter().map(fitness).collect(); - - fits.sort_by(|a, b| a.partial_cmp(&b).unwrap()); - - dbg!(fits); - } -} +mod tests; diff --git a/src/neuralnet.rs b/src/neuralnet.rs new file mode 100644 index 0000000..6f5f25d --- /dev/null +++ b/src/neuralnet.rs @@ -0,0 +1,856 @@ +use std::{ + collections::HashSet, + sync::{ + atomic::{AtomicBool, AtomicUsize, Ordering}, + Arc, + }, +}; + +use atomic_float::AtomicF32; +use genetic_rs::prelude::*; +use rand::Rng; +use replace_with::replace_with_or_abort; + +use crate::{ + activation::{builtin::*, *}, + activation_fn, +}; + +use rayon::prelude::*; + +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +#[cfg(feature = "serde")] +use serde_big_array::BigArray; + +/// The mutation settings for [`NeuralNetwork`]. +/// Does not affect [`NeuralNetwork::mutate`], only [`NeuralNetwork::divide`] and [`NeuralNetwork::crossover`]. +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Clone, PartialEq)] +pub struct MutationSettings { + /// The chance of each mutation type to occur. + pub mutation_rate: f32, + + /// The number of times to try to mutate the network. + pub mutation_passes: usize, + + /// The maximum amount that the weights will be mutated by. + pub weight_mutation_amount: f32, +} + +impl Default for MutationSettings { + fn default() -> Self { + Self { + mutation_rate: 0.01, + mutation_passes: 3, + weight_mutation_amount: 0.5, + } + } +} + +/// An abstract neural network type with `I` input neurons and `O` output neurons. +/// Hidden neurons are not organized into layers, but rather float and link freely +/// (or at least in any way that doesn't cause a cyclic dependency). +/// +/// See [`NeuralNetwork::predict`] for usage. +#[derive(Debug, Clone, PartialEq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct NeuralNetwork { + /// The input layer of neurons. Values specified in [`NeuralNetwork::predict`] will start here. + #[cfg_attr(feature = "serde", serde(with = "BigArray"))] + pub input_layer: [Neuron; I], + + /// The hidden layer(s) of neurons. They are not actually layered, but rather free-floating. + pub hidden_layers: Vec, + + /// The output layer of neurons. Their values will be returned from [`NeuralNetwork::predict`]. + #[cfg_attr(feature = "serde", serde(with = "BigArray"))] + pub output_layer: [Neuron; O], + + /// The mutation settings for the network. + pub mutation_settings: MutationSettings, +} + +impl NeuralNetwork { + // TODO option to set default output layer activations + /// Creates a new random neural network with the given settings. + pub fn new(mutation_settings: MutationSettings, rng: &mut impl Rng) -> Self { + let mut output_layer = Vec::with_capacity(O); + + for _ in 0..O { + output_layer.push(Neuron::new_with_activation( + vec![], + activation_fn!(sigmoid), + rng, + )); + } + + let mut input_layer = Vec::with_capacity(I); + + for _ in 0..I { + let mut already_chosen = Vec::new(); + let outputs = (0..rng.gen_range(1..=O)) + .map(|_| { + let mut j = rng.gen_range(0..O); + while already_chosen.contains(&j) { + j = rng.gen_range(0..O); + } + + output_layer[j].input_count += 1; + already_chosen.push(j); + + (NeuronLocation::Output(j), rng.gen()) + }) + .collect(); + + input_layer.push(Neuron::new_with_activation( + outputs, + activation_fn!(linear_activation), + rng, + )); + } + + let input_layer = input_layer.try_into().unwrap(); + let output_layer = output_layer.try_into().unwrap(); + + Self { + input_layer, + hidden_layers: vec![], + output_layer, + mutation_settings, + } + } + + /// Runs the neural network, propagating values from input to output layer. + pub fn predict(&self, inputs: [f32; I]) -> [f32; O] { + let cache = Arc::new(NeuralNetCache::from(self)); + cache.prime_inputs(inputs); + + (0..I) + .into_par_iter() + .for_each(|i| self.eval(NeuronLocation::Input(i), cache.clone())); + + cache.output() + } + + fn eval(&self, loc: impl AsRef, cache: Arc>) { + let loc = loc.as_ref(); + + if !cache.claim(loc) { + // some other thread is already + // waiting to do this task, currently doing it, or done. + // no need to do it again. + return; + } + + let loc = loc.as_ref(); + while !cache.is_ready(loc) { + // essentially spinlocks until the dependency tasks are complete, + // while letting this thread do some work on random tasks. + rayon::yield_now(); + } + + let val = cache.get(loc); + let n = self.get_neuron(loc); + + n.outputs.par_iter().for_each(|(loc2, weight)| { + cache.add(loc2, n.activate(val * weight)); + self.eval(loc2, cache.clone()); + }); + } + + /// Get a neuron at the specified [`NeuronLocation`]. + pub fn get_neuron(&self, loc: impl AsRef) -> &Neuron { + match loc.as_ref() { + NeuronLocation::Input(i) => &self.input_layer[*i], + NeuronLocation::Hidden(i) => &self.hidden_layers[*i], + NeuronLocation::Output(i) => &self.output_layer[*i], + } + } + + /// Get a mutable reference to the neuron at the specified [`NeuronLocation`]. + pub fn get_neuron_mut(&mut self, loc: impl AsRef) -> &mut Neuron { + match loc.as_ref() { + NeuronLocation::Input(i) => &mut self.input_layer[*i], + NeuronLocation::Hidden(i) => &mut self.hidden_layers[*i], + NeuronLocation::Output(i) => &mut self.output_layer[*i], + } + } + + /// Split a [`Connection`] into two of the same weight, joined by a new [`Neuron`] in the hidden layer(s). + pub fn split_connection(&mut self, connection: Connection, rng: &mut impl Rng) { + let newloc = NeuronLocation::Hidden(self.hidden_layers.len()); + + let a = self.get_neuron_mut(connection.from); + let weight = unsafe { a.remove_connection(connection.to) }.unwrap(); + + a.outputs.push((newloc, weight)); + + let n = Neuron::new(vec![(connection.to, weight)], NeuronScope::HIDDEN, rng); + self.hidden_layers.push(n); + } + + /// Adds a connection but does not check for cyclic linkages. + /// Marked as unsafe because it could cause a hang/livelock when predicting due to cyclic linkage. + /// There is no actual UB or unsafe code associated with it. + pub unsafe fn add_connection_raw(&mut self, connection: Connection, weight: f32) { + let a = self.get_neuron_mut(connection.from); + a.outputs.push((connection.to, weight)); + + // let b = self.get_neuron_mut(connection.to); + // b.inputs.insert(connection.from); + } + + /// Returns false if the connection is cyclic. + pub fn is_connection_safe(&self, connection: Connection) -> bool { + let mut visited = HashSet::from([connection.from]); + + self.dfs(&mut visited, connection.to) + } + + // TODO maybe parallelize + fn dfs(&self, visited: &mut HashSet, current: NeuronLocation) -> bool { + if !visited.insert(current) { + return false; + } + + let n = self.get_neuron(current); + for (loc, _) in &n.outputs { + if !self.dfs(visited, *loc) { + return false; + } + } + + true + } + + /// Safe, checked add connection method. Returns false if it aborted connecting due to cyclic linkage. + pub fn add_connection(&mut self, connection: Connection, weight: f32) -> bool { + if !self.is_connection_safe(connection) { + return false; + } + + unsafe { + self.add_connection_raw(connection, weight); + } + + true + } + + /// Mutates a connection's weight. + pub fn mutate_weight(&mut self, connection: Connection, rng: &mut impl Rng) { + let rate = self.mutation_settings.weight_mutation_amount; + let n = self.get_neuron_mut(connection.from); + n.mutate_weight(connection.to, rate, rng).unwrap(); + } + + /// Get a random valid location within the network. + pub fn random_location(&self, rng: &mut impl Rng) -> NeuronLocation { + match rng.gen_range(0..3) { + 0 => NeuronLocation::Input(rng.gen_range(0..self.input_layer.len())), + 1 => NeuronLocation::Hidden(rng.gen_range(0..self.hidden_layers.len())), + 2 => NeuronLocation::Output(rng.gen_range(0..self.output_layer.len())), + _ => unreachable!(), + } + } + + /// Get a random valid location within a [`NeuronScope`]. + pub fn random_location_in_scope( + &self, + rng: &mut impl Rng, + scope: NeuronScope, + ) -> NeuronLocation { + let loc = self.random_location(rng); + + // this is a lazy and slow way of donig it, TODO better version. + if !scope.contains(NeuronScope::from(loc)) { + return self.random_location_in_scope(rng, scope); + } + + loc + } + + /// Remove a connection and any hanging neurons caused by the deletion. + /// Returns whether there was a hanging neuron. + pub fn remove_connection(&mut self, connection: Connection) -> bool { + let a = self.get_neuron_mut(connection.from); + unsafe { a.remove_connection(connection.to) }.unwrap(); + + let b = self.get_neuron_mut(connection.to); + b.input_count -= 1; + + if b.input_count <= 0 { + self.remove_neuron(connection.to); + return true; + } + + false + } + + /// Remove a neuron and downshift all connection indexes to compensate for it. + pub fn remove_neuron(&mut self, loc: impl AsRef) { + let loc = loc.as_ref(); + if !loc.is_hidden() { + panic!("Can only remove neurons from hidden layer"); + } + + unsafe { + self.downshift_connections(loc.unwrap()); + } + } + + unsafe fn downshift_connections(&mut self, i: usize) { + self.input_layer + .par_iter_mut() + .for_each(|n| n.downshift_outputs(i)); + + self.hidden_layers + .par_iter_mut() + .for_each(|n| n.downshift_outputs(i)); + } + + // TODO maybe more parallelism and pass Connection info. + /// Runs the `callback` on the weights of the neural network in parallel, allowing it to modify weight values. + pub fn map_weights(&mut self, callback: impl Fn(&mut f32) + Sync) { + for n in &mut self.input_layer { + n.outputs.par_iter_mut().for_each(|(_, w)| callback(w)); + } + + for n in &mut self.hidden_layers { + n.outputs.par_iter_mut().for_each(|(_, w)| callback(w)); + } + } + + unsafe fn clear_input_counts(&mut self) { + // not sure whether all this parallelism is necessary or if it will just generate overhead + // rayon::scope(|s| { + // s.spawn(|_| self.input_layer.par_iter_mut().for_each(|n| n.input_count = 0)); + // s.spawn(|_| self.hidden_layers.par_iter_mut().for_each(|n| n.input_count = 0)); + // s.spawn(|_| self.output_layer.par_iter_mut().for_each(|n| n.input_count = 0)); + // }); + + self.input_layer + .par_iter_mut() + .for_each(|n| n.input_count = 0); + self.hidden_layers + .par_iter_mut() + .for_each(|n| n.input_count = 0); + self.output_layer + .par_iter_mut() + .for_each(|n| n.input_count = 0); + } + + /// Recalculates the [`input_count`][`Neuron::input_count`] field for all neurons in the network. + pub fn recalculate_input_counts(&mut self) { + unsafe { self.clear_input_counts() }; + + for i in 0..I { + for j in 0..self.input_layer[i].outputs.len() { + let (loc, _) = self.input_layer[i].outputs[j]; + self.get_neuron_mut(loc).input_count += 1; + } + } + + for i in 0..self.hidden_layers.len() { + for j in 0..self.hidden_layers[i].outputs.len() { + let (loc, _) = self.hidden_layers[i].outputs[j]; + self.get_neuron_mut(loc).input_count += 1; + } + } + } +} + +impl RandomlyMutable for NeuralNetwork { + fn mutate(&mut self, rate: f32, rng: &mut impl Rng) { + if rng.gen::() <= rate { + // split connection + let from = self.random_location_in_scope(rng, !NeuronScope::OUTPUT); + let n = self.get_neuron(from); + let (to, _) = n.random_output(rng); + + self.split_connection(Connection { from, to }, rng); + } + + if rng.gen::() <= rate { + // add connection + let weight = rng.gen::(); + + let from = self.random_location_in_scope(rng, !NeuronScope::OUTPUT); + let to = self.random_location_in_scope(rng, !NeuronScope::INPUT); + + let mut connection = Connection { from, to }; + while !self.add_connection(connection, weight) { + let from = self.random_location_in_scope(rng, !NeuronScope::OUTPUT); + let to = self.random_location_in_scope(rng, !NeuronScope::INPUT); + connection = Connection { from, to }; + } + } + + if rng.gen::() <= rate { + // remove connection + + let from = self.random_location_in_scope(rng, !NeuronScope::OUTPUT); + let a = self.get_neuron(from); + let (to, _) = a.random_output(rng); + + self.remove_connection(Connection { from, to }); + } + + self.map_weights(|w| { + // TODO maybe `Send`able rng. + let mut rng = rand::thread_rng(); + + if rng.gen::() <= rate { + *w += rng.gen_range(-rate..rate); + } + }); + } +} + +impl DivisionReproduction for NeuralNetwork { + fn divide(&self, rng: &mut impl Rng) -> Self { + let mut child = self.clone(); + + for _ in 0..self.mutation_settings.mutation_passes { + child.mutate(child.mutation_settings.mutation_rate, rng); + } + + child + } +} + +impl CrossoverReproduction for NeuralNetwork { + fn crossover(&self, other: &Self, rng: &mut impl rand::Rng) -> Self { + let mut output_layer = self.output_layer.clone(); + + for (i, n) in output_layer.iter_mut().enumerate() { + if rng.gen::() >= 0.5 { + *n = other.output_layer[i].clone(); + } + } + + let hidden_len = self.hidden_layers.len().max(other.hidden_layers.len()); + let mut hidden_layers = Vec::with_capacity(hidden_len); + + for i in 0..hidden_len { + if rng.gen::() >= 0.5 { + if let Some(n) = self.hidden_layers.get(i) { + let mut n = n.clone(); + n.prune_invalid_outputs(hidden_len, O); + + hidden_layers[i] = n; + + continue; + } + } + + let mut n = other.hidden_layers[i].clone(); + n.prune_invalid_outputs(hidden_len, O); + + hidden_layers[i] = n; + } + + let mut input_layer = self.input_layer.clone(); + + for (i, n) in input_layer.iter_mut().enumerate() { + if rng.gen::() >= 0.5 { + *n = other.input_layer[i].clone(); + } + n.prune_invalid_outputs(hidden_len, O); + } + + // crossover mutation settings just in case. + let mutation_settings = if rng.gen::() >= 0.5 { + self.mutation_settings.clone() + } else { + other.mutation_settings.clone() + }; + + let mut child = Self { + input_layer, + hidden_layers, + output_layer, + mutation_settings, + }; + + // TODO maybe find a way to do this while doing crossover stuff instead of recalculating everything. + // would be annoying to implement though. + child.recalculate_input_counts(); + + for _ in 0..child.mutation_settings.mutation_passes { + child.mutate(child.mutation_settings.mutation_rate, rng); + } + + child + } +} + +fn output_exists(loc: NeuronLocation, hidden_len: usize, output_len: usize) -> bool { + match loc { + NeuronLocation::Input(_) => false, + NeuronLocation::Hidden(i) => i < hidden_len, + NeuronLocation::Output(i) => i < output_len, + } +} + +/// A helper struct for operations on connections between neurons. +/// It does not contain information about the weight. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct Connection { + /// The source of the connection. + pub from: NeuronLocation, + + /// The destination of the connection. + pub to: NeuronLocation, +} + +/// A stateless neuron. Contains info about bias, activation, and connections. +#[derive(Debug, Clone, PartialEq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct Neuron { + /// The input count used in [`NeuralNetCache`]. Not safe to modify. + pub input_count: usize, + + /// The connections and weights to other neurons. + pub outputs: Vec<(NeuronLocation, f32)>, + + /// The initial value of the neuron. + pub bias: f32, + + /// The activation function applied to the value before propagating to [`outputs`][Neuron::outputs]. + pub activation_fn: ActivationFn, +} + +impl Neuron { + /// Creates a new neuron with a specified activation function and outputs. + pub fn new_with_activation( + outputs: Vec<(NeuronLocation, f32)>, + activation_fn: ActivationFn, + rng: &mut impl Rng, + ) -> Self { + Self { + input_count: 0, + outputs, + bias: rng.gen(), + activation_fn, + } + } + + /// Creates a new neuron with the given output locations. + /// Chooses a random activation function within the specified scope. + pub fn new( + outputs: Vec<(NeuronLocation, f32)>, + current_scope: NeuronScope, + rng: &mut impl Rng, + ) -> Self { + let reg = ACTIVATION_REGISTRY.read().unwrap(); + let activations = reg.activations_in_scope(current_scope); + + Self::new_with_activations(outputs, activations, rng) + } + + /// Creates a new neuron with the given outputs. + /// Takes a collection of activation functions and chooses a random one from them to use. + pub fn new_with_activations( + outputs: Vec<(NeuronLocation, f32)>, + activations: impl IntoIterator, + rng: &mut impl Rng, + ) -> Self { + // TODO get random in iterator form + let mut activations: Vec<_> = activations.into_iter().collect(); + + // TODO maybe Result instead. + if activations.is_empty() { + panic!("Empty activations list provided"); + } + + Self::new_with_activation( + outputs, + activations.remove(rng.gen_range(0..activations.len())), + rng, + ) + } + + /// Runs the [activation function][Neuron::activation_fn] on the given value and returns it. + pub fn activate(&self, v: f32) -> f32 { + self.activation_fn.func.activate(v) + } + + /// Get the weight of the provided output location. Returns `None` if not found. + pub fn get_weight(&self, output: impl AsRef) -> Option { + let loc = *output.as_ref(); + for out in &self.outputs { + if out.0 == loc { + return Some(out.1); + } + } + + None + } + + /// Tries to remove a connection from the neuron and returns the weight if it was found. + /// Marked as unsafe because it will not update the destination's [`input_count`][Neuron::input_count]. + pub unsafe fn remove_connection(&mut self, output: impl AsRef) -> Option { + let loc = *output.as_ref(); + let mut i = 0; + + while i < self.outputs.len() { + if self.outputs[i].0 == loc { + return Some(self.outputs.remove(i).1); + } + i += 1; + } + + None + } + + /// Randomly mutates the specified weight with the rate. + pub fn mutate_weight( + &mut self, + output: impl AsRef, + rate: f32, + rng: &mut impl Rng, + ) -> Option { + let loc = *output.as_ref(); + let mut i = 0; + + while i < self.outputs.len() { + let o = &mut self.outputs[i]; + if o.0 == loc { + o.1 += rng.gen_range(-rate..rate); + + return Some(o.1); + } + + i += 1; + } + + None + } + + /// Get a random output location and weight. + pub fn random_output(&self, rng: &mut impl Rng) -> (NeuronLocation, f32) { + self.outputs[rng.gen_range(0..self.outputs.len())] + } + + pub(crate) fn downshift_outputs(&mut self, i: usize) { + // TODO par_iter_mut instead of replace + replace_with_or_abort(&mut self.outputs, |o| { + o.into_par_iter() + .map(|(loc, w)| match loc { + NeuronLocation::Hidden(j) if j > i => (NeuronLocation::Hidden(j - 1), w), + _ => (loc, w), + }) + .collect() + }); + } + + /// Removes any outputs pointing to a nonexistent neuron. + pub fn prune_invalid_outputs(&mut self, hidden_len: usize, output_len: usize) { + self.outputs + .retain(|(loc, _)| output_exists(*loc, hidden_len, output_len)); + } +} + +/// A pseudo-pointer of sorts that is used for caching. +#[derive(Hash, Clone, Copy, Debug, Eq, PartialEq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum NeuronLocation { + /// Points to a neuron in the input layer at contained index. + Input(usize), + + /// Points to a neuron in the hidden layer at contained index. + Hidden(usize), + + /// Points to a neuron in the output layer at contained index. + Output(usize), +} + +impl NeuronLocation { + /// Returns `true` if it points to the input layer. Otherwise, returns `false`. + pub fn is_input(&self) -> bool { + matches!(self, Self::Input(_)) + } + + /// Returns `true` if it points to the hidden layer. Otherwise, returns `false`. + pub fn is_hidden(&self) -> bool { + matches!(self, Self::Hidden(_)) + } + + /// Returns `true` if it points to the output layer. Otherwise, returns `false`. + pub fn is_output(&self) -> bool { + matches!(self, Self::Output(_)) + } + + /// Retrieves the index value, regardless of layer. Does not consume. + pub fn unwrap(&self) -> usize { + match self { + Self::Input(i) => *i, + Self::Hidden(i) => *i, + Self::Output(i) => *i, + } + } +} + +impl AsRef for NeuronLocation { + fn as_ref(&self) -> &NeuronLocation { + self + } +} + +/// Handles the state of a single neuron for [`NeuralNetCache`]. +#[derive(Debug, Default)] +pub struct NeuronCache { + /// The value of the neuron. + pub value: AtomicF32, + + /// The expected input count. + pub expected_inputs: usize, + + /// The number of inputs that have finished evaluating. + pub finished_inputs: AtomicUsize, + + /// Whether or not a thread has claimed this neuron to work on it. + pub claimed: AtomicBool, +} + +impl NeuronCache { + /// Creates a new [`NeuronCache`] given relevant info. + /// Use [`NeuronCache::from`] instead to create cache for a [`Neuron`]. + pub fn new(bias: f32, expected_inputs: usize) -> Self { + Self { + value: AtomicF32::new(bias), + expected_inputs, + ..Default::default() + } + } +} + +impl From<&Neuron> for NeuronCache { + fn from(value: &Neuron) -> Self { + Self { + value: AtomicF32::new(value.bias), + expected_inputs: value.input_count, + finished_inputs: AtomicUsize::new(0), + claimed: AtomicBool::new(false), + } + } +} + +/// A cache type used in [`NeuralNetwork::predict`] to track state. +#[derive(Debug)] +pub struct NeuralNetCache { + /// The input layer cache. + pub input_layer: [NeuronCache; I], + + /// The hidden layer(s) cache. + pub hidden_layers: Vec, + + /// The output layer cache. + pub output_layer: [NeuronCache; O], +} + +impl NeuralNetCache { + /// Gets the value of a neuron at the given location. + pub fn get(&self, loc: impl AsRef) -> f32 { + match loc.as_ref() { + NeuronLocation::Input(i) => self.input_layer[*i].value.load(Ordering::SeqCst), + NeuronLocation::Hidden(i) => self.hidden_layers[*i].value.load(Ordering::SeqCst), + NeuronLocation::Output(i) => self.output_layer[*i].value.load(Ordering::SeqCst), + } + } + + /// Adds a value to the neuron at the specified location and increments [`finished_inputs`][NeuronCache::finished_inputs]. + pub fn add(&self, loc: impl AsRef, n: f32) -> f32 { + match loc.as_ref() { + NeuronLocation::Input(i) => self.input_layer[*i].value.fetch_add(n, Ordering::SeqCst), + NeuronLocation::Hidden(i) => { + let c = &self.hidden_layers[*i]; + let v = c.value.fetch_add(n, Ordering::SeqCst); + c.finished_inputs.fetch_add(1, Ordering::SeqCst); + v + } + NeuronLocation::Output(i) => { + let c = &self.output_layer[*i]; + let v = c.value.fetch_add(n, Ordering::SeqCst); + c.finished_inputs.fetch_add(1, Ordering::SeqCst); + v + } + } + } + + /// Returns whether [`finished_inputs`][NeuronCache::finished_inputs] matches [`expected_inputs`][NeuronCache::expected_inputs]. + pub fn is_ready(&self, loc: impl AsRef) -> bool { + match loc.as_ref() { + NeuronLocation::Input(i) => { + let c = &self.input_layer[*i]; + c.expected_inputs >= c.finished_inputs.load(Ordering::SeqCst) + } + NeuronLocation::Hidden(i) => { + let c = &self.hidden_layers[*i]; + c.expected_inputs >= c.finished_inputs.load(Ordering::SeqCst) + } + NeuronLocation::Output(i) => { + let c = &self.output_layer[*i]; + c.expected_inputs >= c.finished_inputs.load(Ordering::SeqCst) + } + } + } + + /// Adds the input values to the input layer of neurons. + pub fn prime_inputs(&self, inputs: [f32; I]) { + for (i, v) in inputs.into_iter().enumerate() { + self.input_layer[i].value.fetch_add(v, Ordering::SeqCst); + } + } + + /// Fetches and packs the output layer values into an array. + pub fn output(&self) -> [f32; O] { + let output: Vec<_> = self + .output_layer + .par_iter() + .map(|c| c.value.load(Ordering::SeqCst)) + .collect(); + + output.try_into().unwrap() + } + + /// Attempts to claim a neuron. Returns false if it has already been claimed. + pub fn claim(&self, loc: impl AsRef) -> bool { + match loc.as_ref() { + NeuronLocation::Input(i) => self.input_layer[*i] + .claimed + .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire) + .is_ok(), + NeuronLocation::Hidden(i) => self.hidden_layers[*i] + .claimed + .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire) + .is_ok(), + NeuronLocation::Output(i) => self.output_layer[*i] + .claimed + .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire) + .is_ok(), + } + } +} + +impl From<&NeuralNetwork> for NeuralNetCache { + fn from(net: &NeuralNetwork) -> Self { + let input_layer: Vec<_> = net.input_layer.par_iter().map(|n| n.into()).collect(); + let input_layer = input_layer.try_into().unwrap(); + + let hidden_layers: Vec<_> = net.hidden_layers.par_iter().map(|n| n.into()).collect(); + let hidden_layers = hidden_layers.try_into().unwrap(); + + let output_layer: Vec<_> = net.output_layer.par_iter().map(|n| n.into()).collect(); + let output_layer = output_layer.try_into().unwrap(); + + Self { + input_layer, + hidden_layers, + output_layer, + } + } +} diff --git a/src/runnable.rs b/src/runnable.rs deleted file mode 100644 index 5b28f54..0000000 --- a/src/runnable.rs +++ /dev/null @@ -1,300 +0,0 @@ -use crate::topology::*; - -#[cfg(not(feature = "rayon"))] -use std::{cell::RefCell, rc::Rc}; - -#[cfg(feature = "rayon")] -use rayon::prelude::*; -#[cfg(feature = "rayon")] -use std::sync::{Arc, RwLock}; - -/// A runnable, stated Neural Network generated from a [NeuralNetworkTopology]. Use [`NeuralNetwork::from`] to go from stateles to runnable. -/// Because this has state, you need to run [`NeuralNetwork::flush_state`] between [`NeuralNetwork::predict`] calls. -#[derive(Debug)] -#[cfg(not(feature = "rayon"))] -pub struct NeuralNetwork { - input_layer: [Rc>; I], - hidden_layers: Vec>>, - output_layer: [Rc>; O], -} - -/// Parallelized version of the [`NeuralNetwork`] struct. -#[derive(Debug)] -#[cfg(feature = "rayon")] -pub struct NeuralNetwork { - input_layer: [Arc>; I], - hidden_layers: Vec>>, - output_layer: [Arc>; O], -} - -impl NeuralNetwork { - /// Predicts an output for the given inputs. - #[cfg(not(feature = "rayon"))] - pub fn predict(&self, inputs: [f32; I]) -> [f32; O] { - for (i, v) in inputs.iter().enumerate() { - let mut nw = self.input_layer[i].borrow_mut(); - nw.state.value = *v; - nw.state.processed = true; - } - - (0..O) - .map(NeuronLocation::Output) - .map(|loc| self.process_neuron(loc)) - .collect::>() - .try_into() - .unwrap() - } - - /// Parallelized prediction of outputs from inputs. - #[cfg(feature = "rayon")] - pub fn predict(&self, inputs: [f32; I]) -> [f32; O] { - inputs.par_iter().enumerate().for_each(|(i, v)| { - let mut nw = self.input_layer[i].write().unwrap(); - nw.state.value = *v; - nw.state.processed = true; - }); - - (0..O) - .map(NeuronLocation::Output) - .collect::>() - .into_par_iter() - .map(|loc| self.process_neuron(loc)) - .collect::>() - .try_into() - .unwrap() - } - - #[cfg(not(feature = "rayon"))] - fn process_neuron(&self, loc: NeuronLocation) -> f32 { - let n = self.get_neuron(loc); - - { - let nr = n.borrow(); - - if nr.state.processed { - return nr.state.value; - } - } - - let mut n = n.borrow_mut(); - - for (l, w) in n.inputs.clone() { - n.state.value += self.process_neuron(l) * w; - } - - n.activate(); - - n.state.value - } - - #[cfg(feature = "rayon")] - fn process_neuron(&self, loc: NeuronLocation) -> f32 { - let n = self.get_neuron(loc); - - { - let nr = n.read().unwrap(); - - if nr.state.processed { - return nr.state.value; - } - } - - let val: f32 = n - .read() - .unwrap() - .inputs - .par_iter() - .map(|&(n2, w)| { - let processed = self.process_neuron(n2); - processed * w - }) - .sum(); - - let mut nw = n.write().unwrap(); - nw.state.value += val; - nw.activate(); - - nw.state.value - } - - #[cfg(not(feature = "rayon"))] - fn get_neuron(&self, loc: NeuronLocation) -> Rc> { - match loc { - NeuronLocation::Input(i) => self.input_layer[i].clone(), - NeuronLocation::Hidden(i) => self.hidden_layers[i].clone(), - NeuronLocation::Output(i) => self.output_layer[i].clone(), - } - } - - #[cfg(feature = "rayon")] - fn get_neuron(&self, loc: NeuronLocation) -> Arc> { - match loc { - NeuronLocation::Input(i) => self.input_layer[i].clone(), - NeuronLocation::Hidden(i) => self.hidden_layers[i].clone(), - NeuronLocation::Output(i) => self.output_layer[i].clone(), - } - } - - /// Flushes the network's state after a [prediction][NeuralNetwork::predict]. - #[cfg(not(feature = "rayon"))] - pub fn flush_state(&self) { - for n in &self.input_layer { - n.borrow_mut().flush_state(); - } - - for n in &self.hidden_layers { - n.borrow_mut().flush_state(); - } - - for n in &self.output_layer { - n.borrow_mut().flush_state(); - } - } - - /// Flushes the neural network's state. - #[cfg(feature = "rayon")] - pub fn flush_state(&self) { - self.input_layer - .par_iter() - .for_each(|n| n.write().unwrap().flush_state()); - - self.hidden_layers - .par_iter() - .for_each(|n| n.write().unwrap().flush_state()); - - self.output_layer - .par_iter() - .for_each(|n| n.write().unwrap().flush_state()); - } -} - -impl From<&NeuralNetworkTopology> for NeuralNetwork { - #[cfg(not(feature = "rayon"))] - fn from(value: &NeuralNetworkTopology) -> Self { - let input_layer = value - .input_layer - .iter() - .map(|n| Rc::new(RefCell::new(Neuron::from(&n.read().unwrap().clone())))) - .collect::>() - .try_into() - .unwrap(); - - let hidden_layers = value - .hidden_layers - .iter() - .map(|n| Rc::new(RefCell::new(Neuron::from(&n.read().unwrap().clone())))) - .collect(); - - let output_layer = value - .output_layer - .iter() - .map(|n| Rc::new(RefCell::new(Neuron::from(&n.read().unwrap().clone())))) - .collect::>() - .try_into() - .unwrap(); - - Self { - input_layer, - hidden_layers, - output_layer, - } - } - - #[cfg(feature = "rayon")] - fn from(value: &NeuralNetworkTopology) -> Self { - let input_layer = value - .input_layer - .iter() - .map(|n| Arc::new(RwLock::new(Neuron::from(&n.read().unwrap().clone())))) - .collect::>() - .try_into() - .unwrap(); - - let hidden_layers = value - .hidden_layers - .iter() - .map(|n| Arc::new(RwLock::new(Neuron::from(&n.read().unwrap().clone())))) - .collect(); - - let output_layer = value - .output_layer - .iter() - .map(|n| Arc::new(RwLock::new(Neuron::from(&n.read().unwrap().clone())))) - .collect::>() - .try_into() - .unwrap(); - - Self { - input_layer, - hidden_layers, - output_layer, - } - } -} - -/// A state-filled neuron. -#[derive(Clone, Debug)] -pub struct Neuron { - inputs: Vec<(NeuronLocation, f32)>, - bias: f32, - - /// The current state of the neuron. - pub state: NeuronState, - - /// The neuron's activation function - pub activation: ActivationFn, -} - -impl Neuron { - /// Flushes a neuron's state. Called by [`NeuralNetwork::flush_state`] - pub fn flush_state(&mut self) { - self.state.value = self.bias; - } - - /// Applies the activation function to the neuron - pub fn activate(&mut self) { - self.state.value = self.activation.func.activate(self.state.value); - } -} - -impl From<&NeuronTopology> for Neuron { - fn from(value: &NeuronTopology) -> Self { - Self { - inputs: value.inputs.clone(), - bias: value.bias, - state: NeuronState { - value: value.bias, - ..Default::default() - }, - activation: value.activation.clone(), - } - } -} - -/// A state used in [`Neuron`]s for cache. -#[derive(Clone, Debug, Default)] -pub struct NeuronState { - /// The current value of the neuron. Initialized to a neuron's bias when flushed. - pub value: f32, - - /// Whether or not [`value`][NeuronState::value] has finished processing. - pub processed: bool, -} - -/// A blanket trait for iterators meant to help with interpreting the output of a [`NeuralNetwork`] -#[cfg(feature = "max-index")] -pub trait MaxIndex { - /// Retrieves the index of the max value. - fn max_index(self) -> usize; -} - -#[cfg(feature = "max-index")] -impl, T: PartialOrd> MaxIndex for I { - // slow and lazy implementation but it works (will prob optimize in the future) - fn max_index(self) -> usize { - self.enumerate() - .max_by(|(_, v), (_, v2)| v.partial_cmp(v2).unwrap()) - .unwrap() - .0 - } -} diff --git a/src/tests.rs b/src/tests.rs new file mode 100644 index 0000000..825cdee --- /dev/null +++ b/src/tests.rs @@ -0,0 +1,179 @@ +use crate::*; +use rand::prelude::*; + +// no support for tuple structs derive in genetic-rs yet :( +#[derive(Debug, Clone, PartialEq)] +struct Agent(NeuralNetwork<4, 1>); + +impl Prunable for Agent {} + +impl RandomlyMutable for Agent { + fn mutate(&mut self, rate: f32, rng: &mut impl Rng) { + self.0.mutate(rate, rng); + } +} + +impl DivisionReproduction for Agent { + fn divide(&self, rng: &mut impl rand::Rng) -> Self { + Self(self.0.divide(rng)) + } +} + +impl CrossoverReproduction for Agent { + fn crossover(&self, other: &Self, rng: &mut impl rand::Rng) -> Self { + Self(self.0.crossover(&other.0, rng)) + } +} + +struct GuessTheNumber(f32); + +impl GuessTheNumber { + fn new(rng: &mut impl Rng) -> Self { + Self(rng.gen()) + } + + fn guess(&self, n: f32) -> Option { + if n > self.0 + 1.0e-5 { + return Some(1.); + } + + if n < self.0 - 1.0e-5 { + return Some(-1.); + } + + // guess was correct (or at least within margin of error). + None + } +} + +fn fitness(agent: &Agent) -> f32 { + let mut rng = rand::thread_rng(); + + let mut fitness = 0.; + + // 10 games for consistency + for _ in 0..10 { + let game = GuessTheNumber::new(&mut rng); + + let mut last_guess = 0.; + let mut last_result = 0.; + + let mut last_guess_2 = 0.; + let mut last_result_2 = 0.; + + let mut steps = 0; + loop { + if steps >= 20 { + // took too many guesses + fitness -= 50.; + break; + } + + let [cur_guess] = + agent + .0 + .predict([last_guess, last_result, last_guess_2, last_result_2]); + + let cur_result = game.guess(cur_guess); + + if let Some(result) = cur_result { + last_guess = last_guess_2; + last_result = last_result_2; + + last_guess_2 = cur_guess; + last_result_2 = result; + + fitness -= 1.; + steps += 1; + + continue; + } + + fitness += 50.; + break; + } + } + + fitness +} + +#[test] +fn division() { + let mut rng = rand::thread_rng(); + + let starting_genomes = (0..100) + .map(|_| Agent(NeuralNetwork::new(MutationSettings::default(), &mut rng))) + .collect(); + + let mut sim = GeneticSim::new(starting_genomes, fitness, division_pruning_nextgen); + + sim.perform_generations(100); +} + +#[test] +fn crossover() { + let mut rng = rand::thread_rng(); + + let starting_genomes = (0..100) + .map(|_| Agent(NeuralNetwork::new(MutationSettings::default(), &mut rng))) + .collect(); + + let mut sim = GeneticSim::new(starting_genomes, fitness, crossover_pruning_nextgen); + + sim.perform_generations(100); +} + +#[cfg(feature = "serde")] +#[test] +fn serde() { + let mut rng = rand::thread_rng(); + let net: NeuralNetwork<5, 10> = NeuralNetwork::new(MutationSettings::default(), &mut rng); + + let text = serde_json::to_string(&net).unwrap(); + + let net2: NeuralNetwork<5, 10> = serde_json::from_str(&text).unwrap(); + + assert_eq!(net, net2); +} + +#[test] +fn neural_net_cache_sync() { + let cache = NeuralNetCache { + input_layer: [NeuronCache::new(0.3, 0), NeuronCache::new(0.25, 0)], + hidden_layers: vec![ + NeuronCache::new(0.2, 2), + NeuronCache::new(0.0, 2), + NeuronCache::new(1.5, 2), + ], + output_layer: [NeuronCache::new(0.0, 3), NeuronCache::new(0.0, 3)], + }; + + for i in 0..2 { + let input_loc = NeuronLocation::Input(i); + + assert!(cache.claim(&input_loc)); + + for j in 0..3 { + cache.add( + NeuronLocation::Hidden(j), + f32::tanh(cache.get(&input_loc) * 1.2), + ); + } + } + + for i in 0..3 { + let hidden_loc = NeuronLocation::Hidden(i); + + assert!(cache.is_ready(&hidden_loc)); + assert!(cache.claim(&hidden_loc)); + + for j in 0..2 { + cache.add( + NeuronLocation::Output(j), + activation::builtin::sigmoid(cache.get(&hidden_loc) * 0.7), + ); + } + } + + assert_eq!(cache.output(), [2.0688455, 2.0688455]); +} diff --git a/src/topology/mod.rs b/src/topology/mod.rs index dd246f2..e69de29 100644 --- a/src/topology/mod.rs +++ b/src/topology/mod.rs @@ -1,638 +0,0 @@ -/// Contains useful structs for serializing/deserializing a [`NeuronTopology`] -#[cfg_attr(docsrs, doc(cfg(feature = "serde")))] -#[cfg(feature = "serde")] -pub mod nnt_serde; - -/// Contains structs and traits used for activation functions. -pub mod activation; - -pub use activation::*; - -use std::{ - collections::HashSet, - sync::{Arc, RwLock}, -}; - -use genetic_rs::prelude::*; -use rand::prelude::*; - -#[cfg(feature = "serde")] -use serde::{Deserialize, Serialize}; - -use crate::activation_fn; - -/// A stateless neural network topology. -/// This is the struct you want to use in your agent's inheritance. -/// See [`NeuralNetwork::from`][crate::NeuralNetwork::from] for how to convert this to a runnable neural network. -#[derive(Debug)] -pub struct NeuralNetworkTopology { - /// The input layer of the neural network. Uses a fixed length of `I`. - pub input_layer: [Arc>; I], - - /// The hidden layers of the neural network. Because neurons have a flexible connection system, all of them exist in the same flat vector. - pub hidden_layers: Vec>>, - - /// The output layer of the neural netowrk. Uses a fixed length of `O`. - pub output_layer: [Arc>; O], - - /// The mutation rate used in [`NeuralNetworkTopology::mutate`] after crossover/division. - pub mutation_rate: f32, - - /// The number of mutation passes (and thus, maximum number of possible mutations that can occur for each entity in the generation). - pub mutation_passes: usize, -} - -impl NeuralNetworkTopology { - /// Creates a new [`NeuralNetworkTopology`]. - pub fn new(mutation_rate: f32, mutation_passes: usize, rng: &mut impl Rng) -> Self { - let input_layer: [Arc>; I] = (0..I) - .map(|_| { - Arc::new(RwLock::new(NeuronTopology::new_with_activation( - vec![], - activation_fn!(linear_activation), - rng, - ))) - }) - .collect::>() - .try_into() - .unwrap(); - - let mut output_layer = Vec::with_capacity(O); - - for _ in 0..O { - // random number of connections to random input neurons. - let input = (0..rng.gen_range(1..=I)) - .map(|_| { - let mut already_chosen = Vec::new(); - let mut i = rng.gen_range(0..I); - while already_chosen.contains(&i) { - i = rng.gen_range(0..I); - } - - already_chosen.push(i); - - NeuronLocation::Input(i) - }) - .collect(); - - output_layer.push(Arc::new(RwLock::new(NeuronTopology::new_with_activation( - input, - activation_fn!(sigmoid), - rng, - )))); - } - - let output_layer = output_layer.try_into().unwrap(); - - Self { - input_layer, - hidden_layers: vec![], - output_layer, - mutation_rate, - mutation_passes, - } - } - - /// Creates a new connection between the neurons. - /// If the connection is cyclic, it does not add a connection and returns false. - /// Otherwise, it returns true. - pub fn add_connection( - &mut self, - from: NeuronLocation, - to: NeuronLocation, - weight: f32, - ) -> bool { - if self.is_connection_cyclic(from, to) { - return false; - } - - // Add the connection since it is not cyclic - self.get_neuron(to) - .write() - .unwrap() - .inputs - .push((from, weight)); - - true - } - - fn is_connection_cyclic(&self, from: NeuronLocation, to: NeuronLocation) -> bool { - if to.is_input() || from.is_output() { - return true; - } - - // check to make sure it isn't duplicate - { - let n = self.get_neuron(to); - let n2 = n.read().unwrap(); - - for (loc, _) in &n2.inputs { - if from == *loc { - return false; - } - } - } - - let mut visited = HashSet::new(); - self.dfs(from, to, &mut visited) - } - - // TODO rayon implementation - fn dfs( - &self, - current: NeuronLocation, - target: NeuronLocation, - visited: &mut HashSet, - ) -> bool { - if current == target { - return true; - } - - visited.insert(current); - - let n = self.get_neuron(current); - let nr = n.read().unwrap(); - - for &(input, _) in &nr.inputs { - if !visited.contains(&input) && self.dfs(input, target, visited) { - return true; - } - } - - visited.remove(¤t); - false - } - - /// Gets a neuron pointer from a [`NeuronLocation`]. - /// You shouldn't ever need to directly call this unless you are doing complex custom mutations. - pub fn get_neuron(&self, loc: NeuronLocation) -> Arc> { - match loc { - NeuronLocation::Input(i) => self.input_layer[i].clone(), - NeuronLocation::Hidden(i) => self.hidden_layers[i].clone(), - NeuronLocation::Output(i) => self.output_layer[i].clone(), - } - } - - /// Gets a random neuron and its location. - pub fn rand_neuron(&self, rng: &mut impl Rng) -> (Arc>, NeuronLocation) { - match rng.gen_range(0..3) { - 0 => { - let i = rng.gen_range(0..self.input_layer.len()); - (self.input_layer[i].clone(), NeuronLocation::Input(i)) - } - 1 if !self.hidden_layers.is_empty() => { - let i = rng.gen_range(0..self.hidden_layers.len()); - (self.hidden_layers[i].clone(), NeuronLocation::Hidden(i)) - } - _ => { - let i = rng.gen_range(0..self.output_layer.len()); - (self.output_layer[i].clone(), NeuronLocation::Output(i)) - } - } - } - - fn delete_neuron(&mut self, loc: NeuronLocation) -> NeuronTopology { - if !loc.is_hidden() { - panic!("Invalid neuron deletion"); - } - - let index = loc.unwrap(); - let neuron = Arc::into_inner(self.hidden_layers.remove(index)).unwrap(); - - for n in &self.hidden_layers { - let mut nw = n.write().unwrap(); - - nw.inputs = nw - .inputs - .iter() - .filter_map(|&(input_loc, w)| { - if !input_loc.is_hidden() { - return Some((input_loc, w)); - } - - if input_loc.unwrap() == index { - return None; - } - - if input_loc.unwrap() > index { - return Some((NeuronLocation::Hidden(input_loc.unwrap() - 1), w)); - } - - Some((input_loc, w)) - }) - .collect(); - } - - for n2 in &self.output_layer { - let mut nw = n2.write().unwrap(); - nw.inputs = nw - .inputs - .iter() - .filter_map(|&(input_loc, w)| { - if !input_loc.is_hidden() { - return Some((input_loc, w)); - } - - if input_loc.unwrap() == index { - return None; - } - - if input_loc.unwrap() > index { - return Some((NeuronLocation::Hidden(input_loc.unwrap() - 1), w)); - } - - Some((input_loc, w)) - }) - .collect(); - } - - neuron.into_inner().unwrap() - } -} - -// need to do all this manually because Arcs are cringe -impl Clone for NeuralNetworkTopology { - fn clone(&self) -> Self { - let input_layer = self - .input_layer - .iter() - .map(|n| Arc::new(RwLock::new(n.read().unwrap().clone()))) - .collect::>() - .try_into() - .unwrap(); - - let hidden_layers = self - .hidden_layers - .iter() - .map(|n| Arc::new(RwLock::new(n.read().unwrap().clone()))) - .collect(); - - let output_layer = self - .output_layer - .iter() - .map(|n| Arc::new(RwLock::new(n.read().unwrap().clone()))) - .collect::>() - .try_into() - .unwrap(); - - Self { - input_layer, - hidden_layers, - output_layer, - mutation_rate: self.mutation_rate, - mutation_passes: self.mutation_passes, - } - } -} - -impl RandomlyMutable for NeuralNetworkTopology { - fn mutate(&mut self, rate: f32, rng: &mut impl rand::Rng) { - for _ in 0..self.mutation_passes { - if rng.gen::() <= rate { - // split preexisting connection - let (mut n2, _) = self.rand_neuron(rng); - - while n2.read().unwrap().inputs.is_empty() { - (n2, _) = self.rand_neuron(rng); - } - - let mut n2 = n2.write().unwrap(); - let i = rng.gen_range(0..n2.inputs.len()); - let (loc, w) = n2.inputs.remove(i); - - let loc3 = NeuronLocation::Hidden(self.hidden_layers.len()); - - let n3 = NeuronTopology::new(vec![loc], ActivationScope::HIDDEN, rng); - - self.hidden_layers.push(Arc::new(RwLock::new(n3))); - - n2.inputs.insert(i, (loc3, w)); - } - - if rng.gen::() <= rate { - // add a connection - let (_, mut loc1) = self.rand_neuron(rng); - let (_, mut loc2) = self.rand_neuron(rng); - - while loc1.is_output() || !self.add_connection(loc1, loc2, rng.gen::()) { - (_, loc1) = self.rand_neuron(rng); - (_, loc2) = self.rand_neuron(rng); - } - } - - if rng.gen::() <= rate && !self.hidden_layers.is_empty() { - // remove a neuron - let (_, mut loc) = self.rand_neuron(rng); - - while !loc.is_hidden() { - (_, loc) = self.rand_neuron(rng); - } - - // delete the neuron - self.delete_neuron(loc); - } - - if rng.gen::() <= rate { - // mutate a connection - let (mut n, _) = self.rand_neuron(rng); - - while n.read().unwrap().inputs.is_empty() { - (n, _) = self.rand_neuron(rng); - } - - let mut n = n.write().unwrap(); - let i = rng.gen_range(0..n.inputs.len()); - let (_, w) = &mut n.inputs[i]; - *w += rng.gen_range(-1.0..1.0) * rate; - } - - if rng.gen::() <= rate { - // mutate bias - let (n, _) = self.rand_neuron(rng); - let mut n = n.write().unwrap(); - - n.bias += rng.gen_range(-1.0..1.0) * rate; - } - - if rng.gen::() <= rate && !self.hidden_layers.is_empty() { - // mutate activation function - let reg = ACTIVATION_REGISTRY.read().unwrap(); - let activations = reg.activations_in_scope(ActivationScope::HIDDEN); - - let (mut n, mut loc) = self.rand_neuron(rng); - - while !loc.is_hidden() { - (n, loc) = self.rand_neuron(rng); - } - - let mut nw = n.write().unwrap(); - - // should probably not clone, but its not a huge efficiency issue anyways - nw.activation = activations[rng.gen_range(0..activations.len())].clone(); - } - } - } -} - -impl DivisionReproduction for NeuralNetworkTopology { - fn divide(&self, rng: &mut impl rand::Rng) -> Self { - let mut child = self.clone(); - child.mutate(self.mutation_rate, rng); - child - } -} - -impl PartialEq for NeuralNetworkTopology { - fn eq(&self, other: &Self) -> bool { - if self.mutation_rate != other.mutation_rate - || self.mutation_passes != other.mutation_passes - { - return false; - } - - for i in 0..I { - if *self.input_layer[i].read().unwrap() != *other.input_layer[i].read().unwrap() { - return false; - } - } - - for i in 0..self.hidden_layers.len().min(other.hidden_layers.len()) { - if *self.hidden_layers[i].read().unwrap() != *other.hidden_layers[i].read().unwrap() { - return false; - } - } - - for i in 0..O { - if *self.output_layer[i].read().unwrap() != *other.output_layer[i].read().unwrap() { - return false; - } - } - - true - } -} - -#[cfg(feature = "serde")] -impl From> - for NeuralNetworkTopology -{ - fn from(value: nnt_serde::NNTSerde) -> Self { - let input_layer = value - .input_layer - .into_iter() - .map(|n| Arc::new(RwLock::new(n))) - .collect::>() - .try_into() - .unwrap(); - - let hidden_layers = value - .hidden_layers - .into_iter() - .map(|n| Arc::new(RwLock::new(n))) - .collect(); - - let output_layer = value - .output_layer - .into_iter() - .map(|n| Arc::new(RwLock::new(n))) - .collect::>() - .try_into() - .unwrap(); - - NeuralNetworkTopology { - input_layer, - hidden_layers, - output_layer, - mutation_rate: value.mutation_rate, - mutation_passes: value.mutation_passes, - } - } -} - -#[cfg(feature = "crossover")] -impl CrossoverReproduction for NeuralNetworkTopology { - fn crossover(&self, other: &Self, rng: &mut impl rand::Rng) -> Self { - let input_layer = self - .input_layer - .iter() - .map(|n| Arc::new(RwLock::new(n.read().unwrap().clone()))) - .collect::>() - .try_into() - .unwrap(); - - let mut hidden_layers = - Vec::with_capacity(self.hidden_layers.len().max(other.hidden_layers.len())); - - for i in 0..hidden_layers.len() { - if rng.gen::() <= 0.5 { - if let Some(n) = self.hidden_layers.get(i) { - let mut n = n.read().unwrap().clone(); - - n.inputs - .retain(|(l, _)| input_exists(*l, &input_layer, &hidden_layers)); - hidden_layers[i] = Arc::new(RwLock::new(n)); - - continue; - } - } - - let mut n = other.hidden_layers[i].read().unwrap().clone(); - - n.inputs - .retain(|(l, _)| input_exists(*l, &input_layer, &hidden_layers)); - hidden_layers[i] = Arc::new(RwLock::new(n)); - } - - let mut output_layer: [Arc>; O] = self - .output_layer - .iter() - .map(|n| Arc::new(RwLock::new(n.read().unwrap().clone()))) - .collect::>() - .try_into() - .unwrap(); - - for (i, n) in self.output_layer.iter().enumerate() { - if rng.gen::() <= 0.5 { - let mut n = n.read().unwrap().clone(); - - n.inputs - .retain(|(l, _)| input_exists(*l, &input_layer, &hidden_layers)); - output_layer[i] = Arc::new(RwLock::new(n)); - - continue; - } - - let mut n = other.output_layer[i].read().unwrap().clone(); - - n.inputs - .retain(|(l, _)| input_exists(*l, &input_layer, &hidden_layers)); - output_layer[i] = Arc::new(RwLock::new(n)); - } - - let mut child = Self { - input_layer, - hidden_layers, - output_layer, - mutation_rate: self.mutation_rate, - mutation_passes: self.mutation_passes, - }; - - child.mutate(self.mutation_rate, rng); - - child - } -} - -#[cfg(feature = "crossover")] -fn input_exists( - loc: NeuronLocation, - input: &[Arc>; I], - hidden: &[Arc>], -) -> bool { - match loc { - NeuronLocation::Input(i) => i < input.len(), - NeuronLocation::Hidden(i) => i < hidden.len(), - NeuronLocation::Output(_) => false, - } -} - -/// A stateless version of [`Neuron`][crate::Neuron]. -#[derive(PartialEq, Debug, Clone)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -pub struct NeuronTopology { - /// The input locations and weights. - pub inputs: Vec<(NeuronLocation, f32)>, - - /// The neuron's bias. - pub bias: f32, - - /// The neuron's activation function. - pub activation: ActivationFn, -} - -impl NeuronTopology { - /// Creates a new neuron with the given input locations. - pub fn new( - inputs: Vec, - current_scope: ActivationScope, - rng: &mut impl Rng, - ) -> Self { - let reg = ACTIVATION_REGISTRY.read().unwrap(); - let activations = reg.activations_in_scope(current_scope); - - Self::new_with_activations(inputs, activations, rng) - } - - /// Takes a collection of activation functions and chooses a random one to use. - pub fn new_with_activations( - inputs: Vec, - activations: impl IntoIterator, - rng: &mut impl Rng, - ) -> Self { - let mut activations: Vec<_> = activations.into_iter().collect(); - - Self::new_with_activation( - inputs, - activations.remove(rng.gen_range(0..activations.len())), - rng, - ) - } - - /// Creates a neuron with the activation. - pub fn new_with_activation( - inputs: Vec, - activation: ActivationFn, - rng: &mut impl Rng, - ) -> Self { - let inputs = inputs - .into_iter() - .map(|i| (i, rng.gen_range(-1.0..1.0))) - .collect(); - - Self { - inputs, - bias: rng.gen(), - activation, - } - } -} - -/// A pseudo-pointer of sorts used to make structural conversions very fast and easy to write. -#[derive(Hash, Clone, Copy, Debug, Eq, PartialEq)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -pub enum NeuronLocation { - /// Points to a neuron in the input layer at contained index. - Input(usize), - - /// Points to a neuron in the hidden layer at contained index. - Hidden(usize), - - /// Points to a neuron in the output layer at contained index. - Output(usize), -} - -impl NeuronLocation { - /// Returns `true` if it points to the input layer. Otherwise, returns `false`. - pub fn is_input(&self) -> bool { - matches!(self, Self::Input(_)) - } - - /// Returns `true` if it points to the hidden layer. Otherwise, returns `false`. - pub fn is_hidden(&self) -> bool { - matches!(self, Self::Hidden(_)) - } - - /// Returns `true` if it points to the output layer. Otherwise, returns `false`. - pub fn is_output(&self) -> bool { - matches!(self, Self::Output(_)) - } - - /// Retrieves the index value, regardless of layer. Does not consume. - pub fn unwrap(&self) -> usize { - match self { - Self::Input(i) => *i, - Self::Hidden(i) => *i, - Self::Output(i) => *i, - } - } -} diff --git a/src/topology/nnt_serde.rs b/src/topology/nnt_serde.rs deleted file mode 100644 index 14f392c..0000000 --- a/src/topology/nnt_serde.rs +++ /dev/null @@ -1,71 +0,0 @@ -use super::*; -use serde::{Deserialize, Serialize}; -use serde_big_array::BigArray; - -/// A serializable wrapper for [`NeuronTopology`]. See [`NNTSerde::from`] for conversion. -#[derive(Serialize, Deserialize)] -pub struct NNTSerde { - #[serde(with = "BigArray")] - pub(crate) input_layer: [NeuronTopology; I], - - pub(crate) hidden_layers: Vec, - - #[serde(with = "BigArray")] - pub(crate) output_layer: [NeuronTopology; O], - - pub(crate) mutation_rate: f32, - pub(crate) mutation_passes: usize, -} - -impl From<&NeuralNetworkTopology> for NNTSerde { - fn from(value: &NeuralNetworkTopology) -> Self { - let input_layer = value - .input_layer - .iter() - .map(|n| n.read().unwrap().clone()) - .collect::>() - .try_into() - .unwrap(); - - let hidden_layers = value - .hidden_layers - .iter() - .map(|n| n.read().unwrap().clone()) - .collect(); - - let output_layer = value - .output_layer - .iter() - .map(|n| n.read().unwrap().clone()) - .collect::>() - .try_into() - .unwrap(); - - Self { - input_layer, - hidden_layers, - output_layer, - mutation_rate: value.mutation_rate, - mutation_passes: value.mutation_passes, - } - } -} - -#[cfg(test)] -#[test] -fn serde() { - let mut rng = rand::thread_rng(); - let nnt = NeuralNetworkTopology::<10, 10>::new(0.1, 3, &mut rng); - let nnts = NNTSerde::from(&nnt); - - let encoded = bincode::serialize(&nnts).unwrap(); - - if let Some(_) = option_env!("TEST_CREATEFILE") { - std::fs::write("serde-test.nn", &encoded).unwrap(); - } - - let decoded: NNTSerde<10, 10> = bincode::deserialize(&encoded).unwrap(); - let nnt2: NeuralNetworkTopology<10, 10> = decoded.into(); - - dbg!(nnt, nnt2); -} From 31b6e7dd9df87824b18c2c9c850a59225d0be012 Mon Sep 17 00:00:00 2001 From: Tristan Murphy <72839119+HyperCodec@users.noreply.github.com> Date: Tue, 4 Feb 2025 14:03:37 +0000 Subject: [PATCH 26/27] delete topology and some examples --- examples/custom_activation.rs | 84 ---------------- examples/plot.rs | 178 ---------------------------------- src/topology/mod.rs | 0 3 files changed, 262 deletions(-) delete mode 100644 examples/custom_activation.rs delete mode 100644 examples/plot.rs delete mode 100644 src/topology/mod.rs diff --git a/examples/custom_activation.rs b/examples/custom_activation.rs deleted file mode 100644 index 7b37c02..0000000 --- a/examples/custom_activation.rs +++ /dev/null @@ -1,84 +0,0 @@ -//! An example implementation of a custom activation function. - -use neat::*; -use rand::prelude::*; - -#[derive(DivisionReproduction, RandomlyMutable, Clone)] -struct AgentDNA { - network: NeuralNetworkTopology<2, 2>, -} - -impl Prunable for AgentDNA {} - -impl GenerateRandom for AgentDNA { - fn gen_random(rng: &mut impl Rng) -> Self { - Self { - network: NeuralNetworkTopology::new(0.01, 3, rng), - } - } -} - -fn fitness(g: &AgentDNA) -> f32 { - let network: NeuralNetwork<2, 2> = NeuralNetwork::from(&g.network); - let mut fitness = 0.; - let mut rng = rand::thread_rng(); - - for _ in 0..50 { - let n = rng.gen::(); - let n2 = rng.gen::(); - - let expected = if (n + n2) / 2. >= 0.5 { 0 } else { 1 }; - - let result = network.predict([n, n2]); - network.flush_state(); - - // partial_cmp chance of returning None in this smh - let result = result.iter().max_index(); - - if result == expected { - fitness += 1.; - } else { - fitness -= 1.; - } - } - - fitness -} - -#[cfg(feature = "serde")] -fn serde_nextgen(rewards: Vec<(AgentDNA, f32)>) -> Vec { - let max = rewards - .iter() - .max_by(|(_, ra), (_, rb)| ra.total_cmp(rb)) - .unwrap(); - - let ser = NNTSerde::from(&max.0.network); - let data = serde_json::to_string_pretty(&ser).unwrap(); - std::fs::write("best-agent.json", data).expect("Failed to write to file"); - - division_pruning_nextgen(rewards) -} - -fn main() { - let sin_activation = activation_fn!(f32::sin); - register_activation(sin_activation); - - #[cfg(not(feature = "rayon"))] - let mut rng = rand::thread_rng(); - - let mut sim = GeneticSim::new( - #[cfg(not(feature = "rayon"))] - Vec::gen_random(&mut rng, 100), - #[cfg(feature = "rayon")] - Vec::gen_random(100), - fitness, - #[cfg(not(feature = "serde"))] - division_pruning_nextgen, - #[cfg(feature = "serde")] - serde_nextgen, - ); - - for _ in 0..200 { - sim.next_generation(); - } -} diff --git a/examples/plot.rs b/examples/plot.rs deleted file mode 100644 index 34fb391..0000000 --- a/examples/plot.rs +++ /dev/null @@ -1,178 +0,0 @@ -use std::{ - error::Error, - sync::{Arc, Mutex}, -}; - -use indicatif::{ProgressBar, ProgressStyle}; -use neat::*; -use plotters::prelude::*; -use rand::prelude::*; - -#[derive(RandomlyMutable, DivisionReproduction, Clone)] -struct AgentDNA { - network: NeuralNetworkTopology<2, 1>, -} - -impl Prunable for AgentDNA {} - -impl GenerateRandom for AgentDNA { - fn gen_random(rng: &mut impl Rng) -> Self { - Self { - network: NeuralNetworkTopology::new(0.01, 3, rng), - } - } -} - -fn fitness(g: &AgentDNA) -> f32 { - let network = NeuralNetwork::from(&g.network); - let mut fitness = 0.; - let mut rng = rand::thread_rng(); - - for _ in 0..100 { - let n = rng.gen::() * 10000.; - let base = rng.gen::() * 10.; - let expected = n.log(base); - - let [answer] = network.predict([n, base]); - network.flush_state(); - - fitness += 5. / (answer - expected).abs(); - } - - fitness -} - -struct PlottingNG> { - performance_stats: Arc>>, - actual_ng: F, -} - -impl> NextgenFn for PlottingNG { - fn next_gen(&self, mut fitness: Vec<(AgentDNA, f32)>) -> Vec { - // it's a bit slower because of sorting twice but I don't want to rewrite the nextgen. - fitness.sort_by(|(_, fa), (_, fb)| fa.partial_cmp(fb).unwrap()); - - let l = fitness.len(); - - let high = fitness[l - 1].1; - - let median = fitness[l / 2].1; - - let low = fitness[0].1; - - let mut ps = self.performance_stats.lock().unwrap(); - ps.push(PerformanceStats { high, median, low }); - - self.actual_ng.next_gen(fitness) - } -} - -struct PerformanceStats { - high: f32, - median: f32, - low: f32, -} - -const OUTPUT_FILE_NAME: &'static str = "fitness-plot.svg"; -const GENS: usize = 1000; - -fn main() -> Result<(), Box> { - #[cfg(not(feature = "rayon"))] - let mut rng = rand::thread_rng(); - - let performance_stats = Arc::new(Mutex::new(Vec::with_capacity(GENS))); - let ng = PlottingNG { - performance_stats: performance_stats.clone(), - actual_ng: division_pruning_nextgen, - }; - - let mut sim = GeneticSim::new( - #[cfg(not(feature = "rayon"))] - Vec::gen_random(&mut rng, 100), - #[cfg(feature = "rayon")] - Vec::gen_random(100), - fitness, - ng, - ); - - let pb = ProgressBar::new(GENS as u64) - .with_style( - ProgressStyle::with_template( - "[{elapsed_precise}] {bar:40.cyan/blue} | {msg} {pos}/{len}", - ) - .unwrap(), - ) - .with_message("gen"); - - println!("Training..."); - - for _ in 0..GENS { - sim.next_generation(); - - pb.inc(1); - } - - pb.finish(); - - // prevent `Arc::into_inner` from failing - drop(sim); - - println!("Training complete, collecting data and building chart..."); - - let root = SVGBackend::new(OUTPUT_FILE_NAME, (640, 480)).into_drawing_area(); - root.fill(&WHITE)?; - - let mut chart = ChartBuilder::on(&root) - .caption( - "agent fitness values per generation", - ("sans-serif", 50).into_font(), - ) - .margin(5) - .x_label_area_size(30) - .y_label_area_size(30) - .build_cartesian_2d(0usize..GENS, 0f32..1000.0)?; - - chart.configure_mesh().draw()?; - - let data: Vec<_> = Arc::into_inner(performance_stats) - .unwrap() - .into_inner() - .unwrap() - .into_iter() - .enumerate() - .collect(); - - let highs = data - .iter() - .map(|(i, PerformanceStats { high, .. })| (*i, *high)); - - let medians = data - .iter() - .map(|(i, PerformanceStats { median, .. })| (*i, *median)); - - let lows = data - .iter() - .map(|(i, PerformanceStats { low, .. })| (*i, *low)); - - chart - .draw_series(LineSeries::new(highs, &GREEN))? - .label("high"); - - chart - .draw_series(LineSeries::new(medians, &YELLOW))? - .label("median"); - - chart.draw_series(LineSeries::new(lows, &RED))?.label("low"); - - chart - .configure_series_labels() - .background_style(&WHITE.mix(0.8)) - .border_style(&BLACK) - .draw()?; - - root.present()?; - - println!("Complete"); - - Ok(()) -} diff --git a/src/topology/mod.rs b/src/topology/mod.rs deleted file mode 100644 index e69de29..0000000 From 0a958f95a75bdec2d8a296747dd29586245ae198 Mon Sep 17 00:00:00 2001 From: Tristan Murphy <72839119+HyperCodec@users.noreply.github.com> Date: Tue, 4 Feb 2025 14:08:31 +0000 Subject: [PATCH 27/27] solve clippy errors --- src/neuralnet.rs | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/neuralnet.rs b/src/neuralnet.rs index 6f5f25d..cce0d61 100644 --- a/src/neuralnet.rs +++ b/src/neuralnet.rs @@ -144,7 +144,6 @@ impl NeuralNetwork { return; } - let loc = loc.as_ref(); while !cache.is_ready(loc) { // essentially spinlocks until the dependency tasks are complete, // while letting this thread do some work on random tasks. @@ -192,7 +191,9 @@ impl NeuralNetwork { } /// Adds a connection but does not check for cyclic linkages. - /// Marked as unsafe because it could cause a hang/livelock when predicting due to cyclic linkage. + /// + /// # Safety + /// This is marked as unsafe because it could cause a hang/livelock when predicting due to cyclic linkage. /// There is no actual UB or unsafe code associated with it. pub unsafe fn add_connection_raw(&mut self, connection: Connection, weight: f32) { let a = self.get_neuron_mut(connection.from); @@ -280,7 +281,7 @@ impl NeuralNetwork { let b = self.get_neuron_mut(connection.to); b.input_count -= 1; - if b.input_count <= 0 { + if b.input_count == 0 { self.remove_neuron(connection.to); return true; } @@ -420,6 +421,7 @@ impl DivisionReproduction for NeuralNetwork CrossoverReproduction for NeuralNetwork { fn crossover(&self, other: &Self, rng: &mut impl rand::Rng) -> Self { let mut output_layer = self.output_layer.clone(); @@ -591,7 +593,10 @@ impl Neuron { } /// Tries to remove a connection from the neuron and returns the weight if it was found. - /// Marked as unsafe because it will not update the destination's [`input_count`][Neuron::input_count]. + /// + /// # Safety + /// This is marked as unsafe because it will not update the destination's [`input_count`][Neuron::input_count]. + /// Similar to [`add_connection_raw`][NeuralNetwork::add_connection_raw], this does not mean UB or anything. pub unsafe fn remove_connection(&mut self, output: impl AsRef) -> Option { let loc = *output.as_ref(); let mut i = 0; @@ -841,8 +846,7 @@ impl From<&NeuralNetwork> for NeuralNetCac let input_layer: Vec<_> = net.input_layer.par_iter().map(|n| n.into()).collect(); let input_layer = input_layer.try_into().unwrap(); - let hidden_layers: Vec<_> = net.hidden_layers.par_iter().map(|n| n.into()).collect(); - let hidden_layers = hidden_layers.try_into().unwrap(); + let hidden_layers = net.hidden_layers.par_iter().map(|n| n.into()).collect(); let output_layer: Vec<_> = net.output_layer.par_iter().map(|n| n.into()).collect(); let output_layer = output_layer.try_into().unwrap();