RatInABox2.0 - Roadmap #84
Replies: 11 comments 3 replies
args, not dicts 👍
A helpful case study in support of args ... Most of you (I'm sure) have seen Grant Sanderson's beautiful 3blue1brown YouTube channel. Grant impressively homebrewed the Similar to here, Grant Sanderson's fork is also trying to remove them: 3b1b/manim#1932

plotting 👍💯
replotting = slow. ... if ratinabox caches plot objects, super recommend scheme we chatted about: The TaskEnvironment has a weak version of this feature -- doesn't replot everything and thus renders quickly. But it's pretty hacky in my view that the environment caches things about its agents and goals. In the long-run, it will be more maintainable to have each class in charge of caching its own plot objects rather than having to change master supervisor class's plot every time the children classes change.

type hinting 👍
Especially easy-to-type variables. Tools like

unit testing 👍

global environment 👍
Possible suggestion: each RIB class could have a list of children (

Jax 🤷♂️
No strong opinions. Leaning partial Jax if the penalty for binary-op/shuttling numpy to a CPU jax.device is low.
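To make the plotting point above concrete, here is a minimal sketch of a class that caches its own plot object and updates it in place instead of replotting. `CachedWall`, its `plot(ax)` signature and the `_artist` attribute are hypothetical names for illustration, not RatInABox API; only the matplotlib calls are real.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt


class CachedWall:
    """Hypothetical plottable object that owns and reuses its own artist."""

    def __init__(self, start, end):
        self.start, self.end = start, end
        self._artist = None  # cached Line2D, created on the first plot() call

    def move(self, start, end):
        self.start, self.end = start, end

    def plot(self, ax):
        xs = [self.start[0], self.end[0]]
        ys = [self.start[1], self.end[1]]
        if self._artist is None or self._artist.axes is not ax:
            # First call (or a different Axes): create the artist once.
            (self._artist,) = ax.plot(xs, ys, color="k")
        else:
            # Later calls: update the existing artist in place, no replot.
            self._artist.set_data(xs, ys)
        return self._artist


fig, ax = plt.subplots()
wall = CachedWall((0, 0), (1, 0))
first = wall.plot(ax)
wall.move((0, 0), (1, 1))
second = wall.plot(ax)
n_lines = len(ax.lines)
```

With this layout a supervising class only needs to loop over its children and call `plot(ax)`; it never has to know what each child draws.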
Sounds like a great idea overall for the longevity of the package! I definitely agree on args instead of dicts, type hinting and unit testing. For global environment, if the cascading update is implemented, I would suggest having a I would suggest an additional section:
Great comments, thanks guys. @SynapticSage 3B1B advice heeded! @colleenjg you're right this could be more modular, for example
These all sound like great changes for RiaB 2.0, and I agree with all of the comments from @SynapticSage and @colleenjg :) I'm a particularly big fan of the global environment updating, as this seems much more concise. My only concern is whether this would slow down updates for really long simulations (like the ones I have been running, e.g. @ 30 Hz x 31 sessions x 40 min/session). It might be ideal to perform more selective updates and skip others if they are going to be static using some sort of argument in

As far as Jax compatibility goes, I would be very much for this if it can actually speed things up for the heavier computations and long simulations, but as you point out it might not save compute time if large arrays are being converted often. I believe it would be worth some case testing in a couple of large simulations before ruling this out.
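A minimal sketch of what a cascading update with selective skipping could look like. The `Toy*` classes, the `static` flag and the `skip_static` argument are all hypothetical stand-ins chosen for illustration; they are not the RatInABox classes or a proposed signature.

```python
class ToyNeurons:
    def __init__(self, name, static=False):
        self.name = name
        self.static = static  # hypothetical flag: this object never changes
        self.n_updates = 0

    def update(self, dt):
        self.n_updates += 1


class ToyAgent:
    def __init__(self):
        self.neurons = []

    def update(self, dt, skip_static=True):
        # Cascade to the neurons, skipping the ones flagged as static.
        for neurons in self.neurons:
            if skip_static and neurons.static:
                continue
            neurons.update(dt)


class ToyEnvironment:
    def __init__(self, dt=1 / 30):
        self.dt = dt
        self.t = 0.0  # the environment owns the global clock
        self.agents = []

    def update(self, skip_static=True):
        # One call advances the clock and cascades through everything else.
        self.t += self.dt
        for agent in self.agents:
            agent.update(self.dt, skip_static=skip_static)


env = ToyEnvironment()
agent = ToyAgent()
moving = ToyNeurons("moving")
frozen = ToyNeurons("frozen", static=True)
agent.neurons += [moving, frozen]
env.agents.append(agent)

for _ in range(3):
    env.update()
```

After three steps `moving` has been updated three times and `frozen` not at all, so the cost of the cascade scales only with the objects that actually change.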
Thanks for the feedback, closing for now.
One thing that just occurred to me, which could be considered: Only passing In typical use cases, to my knowledge, passing both should be redundant, as you can access the figure with
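For reference, a matplotlib `Axes` exposes its parent figure through its `figure` attribute, so a plotting function that receives only an `ax` can still reach the figure. The sketch below shows the pattern; `plot_something` is a hypothetical function, not one from RatInABox.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt


def plot_something(ax=None):
    """Hypothetical plotting function that accepts only `ax`."""
    if ax is None:
        # No Axes given: make a fresh figure and use its Axes.
        _, ax = plt.subplots()
    ax.plot([0, 1], [0, 1])
    return ax


fig, ax = plt.subplots()
returned_ax = plot_something(ax=ax)
recovered_fig = returned_ax.figure  # the figure the Axes belongs to
```

Callers that need the figure (for example to save it) can call `returned_ax.figure.savefig(...)` without the function ever taking or returning `fig`.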
Agreed and added to the list. It's essentially redundant and only adds bloat.
If you add As the primary maintainer of the Fedora Linux package for this project, I’m not sure if packaging https://github.com/google/jax would be feasible for us or not. While it does look like
@musicinmybrain thanks for your feedback - that's ok, I doubt we'd go full
We should consider type hinting for RiaB 2.0. (https://mypy.readthedocs.io/en/stable/cheat_sheet_py3.html) This will make lookups easier for users and help with code autocompletion for anyone working in an integrated development environment. It will mostly involve 2 things:
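As a small illustration of what type hints look like in practice, here is a sketch of an annotated function. `ToyAgent`, `default_speed` and `step` are hypothetical names made up for this example and do not reflect the RatInABox API.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class ToyAgent:
    """Hypothetical stand-in for an agent; not the RatInABox class."""

    position: Tuple[float, float] = (0.0, 0.0)
    default_speed: float = 0.08  # illustrative value only


def step(
    agent: ToyAgent,
    dt: float,
    drift_velocity: Optional[Tuple[float, float]] = None,
) -> Tuple[float, float]:
    """Move the agent for `dt` seconds and return its new position."""
    if drift_velocity is None:
        # No drift given: move along x at the agent's default speed.
        drift_velocity = (agent.default_speed, 0.0)
    vx, vy = drift_velocity
    x, y = agent.position
    agent.position = (x + vx * dt, y + vy * dt)
    return agent.position


toy = ToyAgent()
new_position = step(toy, dt=1.0, drift_velocity=(1.0, 2.0))
```

With annotations like these an IDE can autocomplete `agent.position`, and a checker such as mypy can flag a call like `step(toy, dt="0.1")` before the code runs.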
@mehulrastogi just thinking... do you think a more scalable way to support dynamic environments would be for all environmental objects to be their own class? I can imagine a This is as opposed to the solution I'd previously imagined where the environment itself stores a "state" dictionary which can be updated. The nice thing about this new proposal is that My main concern is that it would create memory issues storing so much more data, but to be honest I'm not sure it would be much more data than Neurons already save without problems. We could also imagine some memory-clever solution since 99% of the time walls won't move.
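One possible shape for this, sketched under assumptions: each environment object keeps its own history and appends to it only when its state changes, which keeps memory small for walls that never move. `ToyWall`, `move` and `state_at` are hypothetical names, and the sketch assumes changes are recorded in increasing time order.

```python
from bisect import bisect_right


class ToyWall:
    """Hypothetical environment object that records its own state history."""

    def __init__(self, start, end, t=0.0):
        self.history = {"t": [], "state": []}
        self._record(t, (start, end))

    def _record(self, t, state):
        # Called only on changes, so a wall that never moves costs one entry.
        self.history["t"].append(t)
        self.history["state"].append(state)

    def move(self, t, start, end):
        self._record(t, (start, end))

    def state_at(self, t):
        """Return the most recent state recorded at or before time t."""
        i = bisect_right(self.history["t"], t) - 1
        if i < 0:
            raise ValueError("no state recorded at or before t")
        return self.history["state"][i]


wall = ToyWall((0, 0), (1, 0))
wall.move(5.0, (0, 0), (1, 1))
early = wall.state_at(2.0)
late = wall.state_at(7.0)
```

A plotting or animation routine could then call `state_at(t)` on every object to draw the environment as it was at time `t`.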
I've begun to think about 2.0. The reason is that there are certainly a couple of choices I made early on in development which weren't optimal. Now could be a good time to fix these, as the community is growing but still small enough that it won't be super disruptive. Fixing them will also make it easier to maintain RiaB in the long run.
I'm opening this issue to get community thoughts on this. @SynapticSage @colleenjg @jquinnlee @mehulrastogi you're some of the most active users I know fairly well so I'm tagging you to get your input (if you have any), but anyone can chip in here. Here's my thoughts:
Essential and backwards incompatible changes (do first):
- Neurons classes in one .py file.
- update(): Given that Environments now know about their Agents and Agents know about their Neurons, we could have just one update function in Env which cascades through everything else. Cleaner?
- dev --> main
- Environment stores the global clock. This just makes sense imo.
- drift_velocity kwarg. Maybe instead Agents can have a policy() method which returns a drift - this would default to the random motion policy, unifying that too. Just something to consider.
- get_state() repeats logic each time. Seems better to have a unique .forward() (or maybe called calculate_firing_rate()) method for each class which receives arrays of positions and head directions etc. and a shared .get_state() which lives in Neurons and calls .forward(). I'm returning to +1 this idea, it makes a lot more sense. Also, instead of get_state() we can have several, such as get_agent_firing_rate(), get_rate_map(), get_angular_tuning_curve() etc. These "get"-functions should return not just lists of the firing rates but also the lists of the respective coordinates (we could maybe use xarray for this but I don't like the extra dependency).
- EnvironmentEntity class could be made which can be added to an Environment which will then update each entity at each time step. Entities will have their own .render methods so Env can loop over them and plot them too. This could be used to flexibly build Teleports, Doors, Keys, etc. for more dynamic environments.

Other essential changes
- update() perhaps adding into new agent/neuron/env specific utils scripts.
- Env.history dictionary. Then, when plotting / animating the environment we can pass in a time argument and the correct state can be retrieved and plotted. The state of the environment only appends to history whenever it changes (e.g. a setter is called).
- plot_environment() it can be passed a fig, an ax and a new object which is a list/dict of plot objects, R, which are all matplotlib.Artists already existing on the figure. The environment can store an equivalent list of plot objects and whenever this changes (e.g. a wall is added or an object is moved etc.) this change is logged. Plotting can then (i) get the list of plot objects corresponding to the correct time and (ii) compare it to the passed list; if they aren't equal then replot the env, otherwise don't bother. Something like that.
- Environments have an Env.history dictionary storing the full "state" of the environment (all object locations, walls, boundaries, etc.). Then Env.plot_environment() takes a time argument, finds the state of the environment at that time and plots that.
- Agent.update() as shown in the paper this is a significant bottleneck.
- animate_ API with plot_ API just with a few extra kwargs.
- ax not fig to figure plotting functions. This may throw up some things but likely minor.
- utils.py into separate ones for the Agent package, Neurons package and Env package and maybe also a misc.
- RatInABox/RatInABox not TomGeorge1234/RatInABox
- RatInABox/RatInABox_RL package containing all the RL stuff (Actor, Critic, ValueNeuron, TDError, TaskEnv etc.)
- IntermediateNeurons subclass for neurons which aren't "fundamental" but take other neurons as inputs. Current examples are FeedForwardLayer and NeuralNetworkNeurons
- DynamicNeurons subclass for neurons which aren't static, i.e. you can't call Neurons.plot_rate_map() because they actually depend on the past history. Examples include TDErrorNeurons (to be made) or anything with recurrency.
- SmoothRandomFeatureNeurons just some spatially tuned but random neurons. Users just provide a length scale. Would be useful for a lot of feature learning studies. Probably something like a Gaussian process underlying these neurons.
- args more kwargs Some of the functions (in particular plotting) have quite bloated argument lists. I think it would be better to remove some of these and allow them to be hidden in **kwargs then defined at the top of the function arg = kwargs.get("arg_name", default_val). This is backwards compatible, cleans up the docstrings so they are more readable, and we can use this to expose any/all free parameters (even ones which weren't previously arguments) for greater flexibility. I have done this to the Agent.plot_trajectory()
- np.histogram2d. See issue Improving how RatInABox estimates and facilitates rate map estimation from observed data #125

Things to consider
- Neurons should follow torch.nn.Module API - this would make the evaluation of complex feedforward graphs, which currently happens in a backwards manner, more efficient. This might require renaming the .get_state() method to .forward(). Need to think more about this
- np --> jnp everywhere.

I'm not a software guy so @SynapticSage @mehulrastogi feel free to give high level comments about best way to go forward.
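A minimal sketch of the .forward() / shared .get_state() split mentioned in the roadmap, using NumPy. The class names, constructor arguments and the Gaussian tuning curve are assumptions made for illustration; this is not the current RatInABox implementation.

```python
import numpy as np


class ToyNeurons:
    """Hypothetical base class: one shared get_state() that calls forward()."""

    def __init__(self, agent_pos):
        self.agent_pos = np.asarray(agent_pos, dtype=float)

    def forward(self, pos):
        # Each subclass implements only this: positions in, firing rates out.
        raise NotImplementedError

    def get_state(self, pos=None):
        # Shared logic lives here once: default to the agent's position and
        # always hand forward() an array of shape (n_positions, 2).
        if pos is None:
            pos = self.agent_pos
        pos = np.atleast_2d(np.asarray(pos, dtype=float))
        return self.forward(pos)


class ToyGaussianCells(ToyNeurons):
    def __init__(self, agent_pos, centres, width=0.2):
        super().__init__(agent_pos)
        self.centres = np.asarray(centres, dtype=float)  # (n_cells, 2)
        self.width = width

    def forward(self, pos):
        # Pairwise distances, shape (n_cells, n_positions).
        diff = self.centres[:, None, :] - pos[None, :, :]
        dist = np.linalg.norm(diff, axis=-1)
        return np.exp(-(dist**2) / (2 * self.width**2))


cells = ToyGaussianCells([0.5, 0.5], centres=[[0.5, 0.5], [0.0, 0.0]])
rates_now = cells.get_state()  # evaluated at the agent's position
rate_map = cells.get_state([[0.0, 0.0], [0.5, 0.5], [1.0, 1.0]])
```

Because the same `forward()` serves both the agent's current position and an arbitrary grid of positions, helpers like a rate-map getter would not need to repeat any per-class logic.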