-
Notifications
You must be signed in to change notification settings - Fork 430
[BugFix] Fix agent_dim in multiagent nets & account for neg dims #3290
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/3290
Note: Links to docs will display an error until the docs builds have been completed. This comment was automatically generated by Dr. CI and updates every 15 minutes. |
matteobettini
approved these changes
Jan 6, 2026
|
| Name | Max | Mean | Ops | Ops on Repo HEAD
|
Change |
|---|---|---|---|---|---|
| test_tensor_to_bytestream_speed[pickle] | 82.6993μs | 81.3329μs | 12.2951 KOps/s | 12.2281 KOps/s | |
| test_tensor_to_bytestream_speed[torch.save] | 0.1479ms | 0.1473ms | 6.7904 KOps/s | 7.0998 KOps/s | |
| test_tensor_to_bytestream_speed[untyped_storage] | 0.1266s | 0.1262s | 7.9215 Ops/s | 8.0767 Ops/s | |
| test_tensor_to_bytestream_speed[numpy] | 2.7284μs | 2.7182μs | 367.8900 KOps/s | 365.3907 KOps/s | |
| test_tensor_to_bytestream_speed[safetensors] | 42.2566μs | 40.7673μs | 24.5295 KOps/s | 25.4719 KOps/s | |
| test_simple | 0.5485s | 0.5476s | 1.8263 Ops/s | 1.7354 Ops/s | |
| test_transformed | 1.1276s | 1.1272s | 0.8872 Ops/s | 0.8703 Ops/s | |
| test_serial | 1.6723s | 1.6699s | 0.5988 Ops/s | 0.5889 Ops/s | |
| test_parallel | 1.0880s | 1.0845s | 0.9221 Ops/s | 0.8756 Ops/s | |
| test_step_mdp_speed[True-True-True-True-True] | 0.2036ms | 43.8040μs | 22.8290 KOps/s | 22.4690 KOps/s | |
| test_step_mdp_speed[True-True-True-True-False] | 61.6810μs | 25.1850μs | 39.7062 KOps/s | 39.0643 KOps/s | |
| test_step_mdp_speed[True-True-True-False-True] | 61.6620μs | 25.3757μs | 39.4078 KOps/s | 39.3472 KOps/s | |
| test_step_mdp_speed[True-True-True-False-False] | 46.0910μs | 13.8601μs | 72.1498 KOps/s | 71.7364 KOps/s | |
| test_step_mdp_speed[True-True-False-True-True] | 93.9020μs | 48.1269μs | 20.7784 KOps/s | 20.8005 KOps/s | |
| test_step_mdp_speed[True-True-False-True-False] | 67.8320μs | 27.9533μs | 35.7739 KOps/s | 35.2163 KOps/s | |
| test_step_mdp_speed[True-True-False-False-True] | 60.1820μs | 28.2058μs | 35.4538 KOps/s | 35.4637 KOps/s | |
| test_step_mdp_speed[True-True-False-False-False] | 53.9420μs | 16.7677μs | 59.6383 KOps/s | 58.7384 KOps/s | |
| test_step_mdp_speed[True-False-True-True-True] | 85.9520μs | 51.4973μs | 19.4185 KOps/s | 19.3835 KOps/s | |
| test_step_mdp_speed[True-False-True-True-False] | 66.4820μs | 30.6820μs | 32.5924 KOps/s | 31.8085 KOps/s | |
| test_step_mdp_speed[True-False-True-False-True] | 63.6820μs | 28.0254μs | 35.6819 KOps/s | 35.6467 KOps/s | |
| test_step_mdp_speed[True-False-True-False-False] | 48.7510μs | 16.5329μs | 60.4853 KOps/s | 59.1919 KOps/s | |
| test_step_mdp_speed[True-False-False-True-True] | 94.0530μs | 52.9396μs | 18.8894 KOps/s | 18.6762 KOps/s | |
| test_step_mdp_speed[True-False-False-True-False] | 68.2610μs | 33.5199μs | 29.8330 KOps/s | 29.4835 KOps/s | |
| test_step_mdp_speed[True-False-False-False-True] | 97.5830μs | 30.6481μs | 32.6284 KOps/s | 33.0901 KOps/s | |
| test_step_mdp_speed[True-False-False-False-False] | 87.8920μs | 19.1335μs | 52.2642 KOps/s | 51.6373 KOps/s | |
| test_step_mdp_speed[False-True-True-True-True] | 89.6420μs | 51.5886μs | 19.3841 KOps/s | 19.7579 KOps/s | |
| test_step_mdp_speed[False-True-True-True-False] | 83.0620μs | 30.5377μs | 32.7464 KOps/s | 32.6237 KOps/s | |
| test_step_mdp_speed[False-True-True-False-True] | 2.3529ms | 32.6957μs | 30.5851 KOps/s | 31.2007 KOps/s | |
| test_step_mdp_speed[False-True-True-False-False] | 50.6410μs | 18.1809μs | 55.0029 KOps/s | 54.1153 KOps/s | |
| test_step_mdp_speed[False-True-False-True-True] | 96.0930μs | 54.1152μs | 18.4791 KOps/s | 18.6895 KOps/s | |
| test_step_mdp_speed[False-True-False-True-False] | 65.1910μs | 33.4891μs | 29.8605 KOps/s | 29.7964 KOps/s | |
| test_step_mdp_speed[False-True-False-False-True] | 71.8720μs | 34.3845μs | 29.0829 KOps/s | 29.4168 KOps/s | |
| test_step_mdp_speed[False-True-False-False-False] | 51.8310μs | 21.2461μs | 47.0676 KOps/s | 47.4042 KOps/s | |
| test_step_mdp_speed[False-False-True-True-True] | 85.0820μs | 56.1101μs | 17.8221 KOps/s | 17.8811 KOps/s | |
| test_step_mdp_speed[False-False-True-True-False] | 70.8810μs | 36.4616μs | 27.4261 KOps/s | 26.9593 KOps/s | |
| test_step_mdp_speed[False-False-True-False-True] | 69.5010μs | 34.8484μs | 28.6957 KOps/s | 29.0994 KOps/s | |
| test_step_mdp_speed[False-False-True-False-False] | 52.6510μs | 21.1085μs | 47.3742 KOps/s | 45.7830 KOps/s | |
| test_step_mdp_speed[False-False-False-True-True] | 98.5020μs | 56.7653μs | 17.6164 KOps/s | 17.0015 KOps/s | |
| test_step_mdp_speed[False-False-False-True-False] | 69.7120μs | 37.8779μs | 26.4006 KOps/s | 25.6243 KOps/s | |
| test_step_mdp_speed[False-False-False-False-True] | 78.0520μs | 36.9339μs | 27.0754 KOps/s | 27.5003 KOps/s | |
| test_step_mdp_speed[False-False-False-False-False] | 56.6620μs | 23.6655μs | 42.2556 KOps/s | 42.8864 KOps/s | |
| test_non_tensor_env_rollout_speed[1000-single-True] | 0.8780s | 0.7777s | 1.2858 Ops/s | 1.2978 Ops/s | |
| test_non_tensor_env_rollout_speed[1000-single-False] | 0.7430s | 0.6431s | 1.5550 Ops/s | 1.5756 Ops/s | |
| test_non_tensor_env_rollout_speed[1000-serial-no-buffers-True] | 1.7738s | 1.6892s | 0.5920 Ops/s | 0.5916 Ops/s | |
| test_non_tensor_env_rollout_speed[1000-serial-no-buffers-False] | 1.5476s | 1.4676s | 0.6814 Ops/s | 0.6839 Ops/s | |
| test_non_tensor_env_rollout_speed[1000-serial-buffers-True] | 2.0268s | 1.9401s | 0.5154 Ops/s | 0.5162 Ops/s | |
| test_non_tensor_env_rollout_speed[1000-serial-buffers-False] | 1.7961s | 1.7110s | 0.5845 Ops/s | 0.5869 Ops/s | |
| test_non_tensor_env_rollout_speed[1000-parallel-no-buffers-True] | 4.7072s | 4.6382s | 0.2156 Ops/s | 0.2132 Ops/s | |
| test_non_tensor_env_rollout_speed[1000-parallel-no-buffers-False] | 4.5536s | 4.4390s | 0.2253 Ops/s | 0.2238 Ops/s | |
| test_non_tensor_env_rollout_speed[1000-parallel-buffers-True] | 2.0641s | 1.9529s | 0.5121 Ops/s | 0.5089 Ops/s | |
| test_non_tensor_env_rollout_speed[1000-parallel-buffers-False] | 1.8310s | 1.6935s | 0.5905 Ops/s | 0.5976 Ops/s | |
| test_values[generalized_advantage_estimate-True-True] | 10.7394ms | 10.5599ms | 94.6982 Ops/s | 95.7692 Ops/s | |
| test_values[vec_generalized_advantage_estimate-True-True] | 15.1273ms | 11.1774ms | 89.4663 Ops/s | 88.0892 Ops/s | |
| test_values[td0_return_estimate-False-False] | 0.2249ms | 0.1292ms | 7.7407 KOps/s | 7.8396 KOps/s | |
| test_values[td1_return_estimate-False-False] | 27.8360ms | 27.5819ms | 36.2556 Ops/s | 36.2726 Ops/s | |
| test_values[vec_td1_return_estimate-False-False] | 12.1282ms | 11.3185ms | 88.3511 Ops/s | 88.3196 Ops/s | |
| test_values[td_lambda_return_estimate-True-False] | 41.2380ms | 40.7713ms | 24.5270 Ops/s | 24.4650 Ops/s | |
| test_values[vec_td_lambda_return_estimate-True-False] | 11.4381ms | 11.2298ms | 89.0489 Ops/s | 88.9299 Ops/s | |
| test_gae_speed[generalized_advantage_estimate-False-1-512] | 9.4760ms | 9.3765ms | 106.6492 Ops/s | 107.0924 Ops/s | |
| test_gae_speed[vec_generalized_advantage_estimate-True-1-512] | 1.9990ms | 1.5327ms | 652.4530 Ops/s | 677.4207 Ops/s | |
| test_gae_speed[vec_generalized_advantage_estimate-False-1-512] | 0.4878ms | 0.4191ms | 2.3860 KOps/s | 2.3320 KOps/s | |
| test_gae_speed[vec_generalized_advantage_estimate-True-32-512] | 29.7088ms | 23.5018ms | 42.5499 Ops/s | 42.4697 Ops/s | |
| test_gae_speed[vec_generalized_advantage_estimate-False-32-512] | 1.9434ms | 1.7663ms | 566.1697 Ops/s | 580.2906 Ops/s | |
| test_dqn_speed[False-None] | 1.7762ms | 1.4160ms | 706.2111 Ops/s | 710.7182 Ops/s | |
| test_dqn_speed[False-backward] | 2.0579ms | 1.9671ms | 508.3583 Ops/s | 509.9209 Ops/s | |
| test_dqn_speed[True-None] | 0.8577ms | 0.5265ms | 1.8994 KOps/s | 1.8689 KOps/s | |
| test_dqn_speed[True-backward] | 1.1022ms | 0.9848ms | 1.0154 KOps/s | 847.3147 Ops/s | |
| test_dqn_speed[reduce-overhead-None] | 0.6457ms | 0.5313ms | 1.8823 KOps/s | 1.8308 KOps/s | |
| test_dqn_speed[reduce-overhead-backward] | 1.0139ms | 0.9686ms | 1.0325 KOps/s | 860.8465 Ops/s | |
| test_ddpg_speed[False-None] | 3.1107ms | 2.8451ms | 351.4765 Ops/s | 326.6155 Ops/s | |
| test_ddpg_speed[False-backward] | 4.1568ms | 4.0774ms | 245.2547 Ops/s | 244.4846 Ops/s | |
| test_ddpg_speed[True-None] | 1.8136ms | 1.4100ms | 709.2230 Ops/s | 716.4580 Ops/s | |
| test_ddpg_speed[True-backward] | 2.4321ms | 2.3871ms | 418.9241 Ops/s | 389.3780 Ops/s | |
| test_ddpg_speed[reduce-overhead-None] | 1.5388ms | 1.3876ms | 720.6597 Ops/s | 714.1381 Ops/s | |
| test_ddpg_speed[reduce-overhead-backward] | 2.3917ms | 2.3471ms | 426.0571 Ops/s | 422.4037 Ops/s | |
| test_sac_speed[False-None] | 8.4209ms | 7.9490ms | 125.8022 Ops/s | 124.5684 Ops/s | |
| test_sac_speed[False-backward] | 11.6013ms | 11.2531ms | 88.8647 Ops/s | 88.1007 Ops/s | |
| test_sac_speed[True-None] | 2.5639ms | 2.1280ms | 469.9330 Ops/s | 469.4503 Ops/s | |
| test_sac_speed[True-backward] | 4.1292ms | 3.9905ms | 250.5946 Ops/s | 237.0799 Ops/s | |
| test_sac_speed[reduce-overhead-None] | 2.5514ms | 2.1263ms | 470.2908 Ops/s | 448.2523 Ops/s | |
| test_sac_speed[reduce-overhead-backward] | 4.1263ms | 4.0087ms | 249.4594 Ops/s | 224.7700 Ops/s | |
| test_redq_speed[False-None] | 10.8531ms | 10.3315ms | 96.7918 Ops/s | 96.2031 Ops/s | |
| test_redq_speed[False-backward] | 18.7126ms | 17.9227ms | 55.7952 Ops/s | 55.8147 Ops/s | |
| test_redq_speed[True-None] | 4.6012ms | 4.3906ms | 227.7570 Ops/s | 225.7943 Ops/s | |
| test_redq_speed[True-backward] | 10.0067ms | 9.8448ms | 101.5764 Ops/s | 100.3225 Ops/s | |
| test_redq_speed[reduce-overhead-None] | 4.4605ms | 4.3061ms | 232.2280 Ops/s | 221.7998 Ops/s | |
| test_redq_speed[reduce-overhead-backward] | 10.0442ms | 9.8419ms | 101.6062 Ops/s | 101.2865 Ops/s | |
| test_redq_deprec_speed[False-None] | 11.5150ms | 11.0978ms | 90.1082 Ops/s | 91.1731 Ops/s | |
| test_redq_deprec_speed[False-backward] | 16.3947ms | 16.0075ms | 62.4707 Ops/s | 63.4763 Ops/s | |
| test_redq_deprec_speed[True-None] | 5.8107ms | 3.6381ms | 274.8701 Ops/s | 262.0631 Ops/s | |
| test_redq_deprec_speed[True-backward] | 7.8009ms | 7.6029ms | 131.5279 Ops/s | 129.0063 Ops/s | |
| test_redq_deprec_speed[reduce-overhead-None] | 3.8741ms | 3.6289ms | 275.5663 Ops/s | 275.7883 Ops/s | |
| test_redq_deprec_speed[reduce-overhead-backward] | 7.8440ms | 7.6192ms | 131.2469 Ops/s | 131.2608 Ops/s | |
| test_td3_speed[False-None] | 8.1476ms | 8.0450ms | 124.3008 Ops/s | 123.5412 Ops/s | |
| test_td3_speed[False-backward] | 11.5054ms | 10.9708ms | 91.1508 Ops/s | 91.8971 Ops/s | |
| test_td3_speed[True-None] | 1.9952ms | 1.8144ms | 551.1506 Ops/s | 546.7887 Ops/s | |
| test_td3_speed[True-backward] | 3.7673ms | 3.6655ms | 272.8121 Ops/s | 275.2733 Ops/s | |
| test_td3_speed[reduce-overhead-None] | 1.8213ms | 1.7812ms | 561.4288 Ops/s | 556.8987 Ops/s | |
| test_td3_speed[reduce-overhead-backward] | 3.7726ms | 3.6480ms | 274.1230 Ops/s | 228.8047 Ops/s | |
| test_cql_speed[False-None] | 28.9061ms | 26.0119ms | 38.4440 Ops/s | 38.0361 Ops/s | |
| test_cql_speed[False-backward] | 38.1885ms | 35.3177ms | 28.3144 Ops/s | 27.9663 Ops/s | |
| test_cql_speed[True-None] | 15.6717ms | 12.8235ms | 77.9819 Ops/s | 78.5781 Ops/s | |
| test_cql_speed[True-backward] | 18.5502ms | 18.0977ms | 55.2558 Ops/s | 55.2775 Ops/s | |
| test_cql_speed[reduce-overhead-None] | 12.8400ms | 12.3833ms | 80.7542 Ops/s | 80.3791 Ops/s | |
| test_cql_speed[reduce-overhead-backward] | 18.3252ms | 17.9797ms | 55.6181 Ops/s | 55.4551 Ops/s | |
| test_a2c_speed[False-None] | 5.8407ms | 5.3736ms | 186.0953 Ops/s | 181.2461 Ops/s | |
| test_a2c_speed[False-backward] | 12.0596ms | 11.8048ms | 84.7111 Ops/s | 84.2712 Ops/s | |
| test_a2c_speed[True-None] | 4.0373ms | 3.6600ms | 273.2212 Ops/s | 265.9268 Ops/s | |
| test_a2c_speed[True-backward] | 8.9322ms | 8.5745ms | 116.6253 Ops/s | 108.3167 Ops/s | |
| test_a2c_speed[reduce-overhead-None] | 3.8546ms | 3.6867ms | 271.2420 Ops/s | 272.5459 Ops/s | |
| test_a2c_speed[reduce-overhead-backward] | 10.5477ms | 8.6758ms | 115.2626 Ops/s | 115.1688 Ops/s | |
| test_ppo_speed[False-None] | 6.0304ms | 5.8357ms | 171.3578 Ops/s | 169.8127 Ops/s | |
| test_ppo_speed[False-backward] | 12.7064ms | 12.4863ms | 80.0878 Ops/s | 79.7417 Ops/s | |
| test_ppo_speed[True-None] | 3.9126ms | 3.6018ms | 277.6388 Ops/s | 277.4255 Ops/s | |
| test_ppo_speed[True-backward] | 8.7381ms | 8.2887ms | 120.6465 Ops/s | 119.6655 Ops/s | |
| test_ppo_speed[reduce-overhead-None] | 3.9107ms | 3.5881ms | 278.6952 Ops/s | 278.4202 Ops/s | |
| test_ppo_speed[reduce-overhead-backward] | 8.8976ms | 8.6684ms | 115.3620 Ops/s | 116.5420 Ops/s | |
| test_reinforce_speed[False-None] | 4.7733ms | 4.5076ms | 221.8463 Ops/s | 217.0544 Ops/s | |
| test_reinforce_speed[False-backward] | 7.7021ms | 7.3043ms | 136.9061 Ops/s | 134.4710 Ops/s | |
| test_reinforce_speed[True-None] | 3.1806ms | 2.8625ms | 349.3480 Ops/s | 331.3837 Ops/s | |
| test_reinforce_speed[True-backward] | 7.7931ms | 7.5869ms | 131.8069 Ops/s | 129.6269 Ops/s | |
| test_reinforce_speed[reduce-overhead-None] | 3.1075ms | 2.8386ms | 352.2845 Ops/s | 345.4863 Ops/s | |
| test_reinforce_speed[reduce-overhead-backward] | 8.1992ms | 7.8173ms | 127.9217 Ops/s | 124.7871 Ops/s | |
| test_iql_speed[False-None] | 24.5914ms | 19.7995ms | 50.5064 Ops/s | 49.4755 Ops/s | |
| test_iql_speed[False-backward] | 36.0014ms | 30.2744ms | 33.0312 Ops/s | 32.4117 Ops/s | |
| test_iql_speed[True-None] | 8.8305ms | 8.4463ms | 118.3957 Ops/s | 114.5598 Ops/s | |
| test_iql_speed[True-backward] | 16.9277ms | 16.5724ms | 60.3415 Ops/s | 59.9428 Ops/s | |
| test_iql_speed[reduce-overhead-None] | 9.1959ms | 8.5006ms | 117.6394 Ops/s | 116.1037 Ops/s | |
| test_iql_speed[reduce-overhead-backward] | 17.4595ms | 16.9210ms | 59.0983 Ops/s | 56.5836 Ops/s | |
| test_rb_sample[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 8.1881ms | 6.1336ms | 163.0365 Ops/s | 162.2588 Ops/s | |
| test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 0.6909ms | 0.3410ms | 2.9328 KOps/s | 3.2294 KOps/s | |
| test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 0.6093ms | 0.3393ms | 2.9473 KOps/s | 3.3287 KOps/s | |
| test_rb_sample[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 6.2284ms | 5.8578ms | 170.7134 Ops/s | 171.0752 Ops/s | |
| test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 2.6293ms | 0.3311ms | 3.0204 KOps/s | 3.1868 KOps/s | |
| test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 0.6112ms | 0.3322ms | 3.0101 KOps/s | 3.0269 KOps/s | |
| test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-sampler6-10000] | 1.6970ms | 1.4125ms | 707.9707 Ops/s | 711.5486 Ops/s | |
| test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-sampler7-10000] | 1.6334ms | 1.3340ms | 749.6367 Ops/s | 843.7122 Ops/s | |
| test_rb_sample[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 9.4421ms | 6.1037ms | 163.8351 Ops/s | 167.6755 Ops/s | |
| test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 1.2769ms | 0.5320ms | 1.8796 KOps/s | 2.1045 KOps/s | |
| test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 0.6967ms | 0.4723ms | 2.1171 KOps/s | 2.3965 KOps/s | |
| test_rb_iterate[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 5.8575ms | 5.7613ms | 173.5719 Ops/s | 169.7565 Ops/s | |
| test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 0.9526ms | 0.3626ms | 2.7575 KOps/s | 3.5026 KOps/s | |
| test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 0.5270ms | 0.2923ms | 3.4208 KOps/s | 3.6847 KOps/s | |
| test_rb_iterate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 5.9676ms | 5.7076ms | 175.2037 Ops/s | 172.7415 Ops/s | |
| test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 0.9596ms | 0.3638ms | 2.7486 KOps/s | 3.5596 KOps/s | |
| test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 0.5347ms | 0.3176ms | 3.1485 KOps/s | 3.7917 KOps/s | |
| test_rb_iterate[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 6.1933ms | 6.0519ms | 165.2367 Ops/s | 166.2220 Ops/s | |
| test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 2.6554ms | 0.5438ms | 1.8390 KOps/s | 2.0768 KOps/s | |
| test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 0.6710ms | 0.5010ms | 1.9960 KOps/s | 2.1033 KOps/s | |
| test_rb_populate[TensorDictReplayBuffer-ListStorage-RandomSampler-400] | 6.4677ms | 5.0325ms | 198.7097 Ops/s | 196.4406 Ops/s | |
| test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] | 6.2891ms | 2.3416ms | 427.0619 Ops/s | 431.2586 Ops/s | |
| test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] | 8.3182ms | 1.2467ms | 802.1367 Ops/s | 816.4294 Ops/s | |
| test_rb_populate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] | 7.5763ms | 5.1038ms | 195.9333 Ops/s | 52.7070 Ops/s | |
| test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] | 7.8427ms | 2.3315ms | 428.9066 Ops/s | 473.7192 Ops/s | |
| test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] | 7.6212ms | 1.2290ms | 813.6814 Ops/s | 887.2695 Ops/s | |
| test_rb_populate[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] | 0.6287s | 17.7317ms | 56.3962 Ops/s | 188.3774 Ops/s | |
| test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] | 3.9758ms | 1.8928ms | 528.3244 Ops/s | 439.9702 Ops/s | |
| test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] | 8.1131ms | 1.3656ms | 732.2834 Ops/s | 932.8560 Ops/s | |
| test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-10000-10000-100-True] | 36.5222ms | 34.3862ms | 29.0814 Ops/s | 28.4421 Ops/s | |
| test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-10000-10000-100-False] | 19.3565ms | 17.6533ms | 56.6467 Ops/s | 56.0167 Ops/s | |
| test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-100000-10000-100-True] | 37.5454ms | 35.2636ms | 28.3579 Ops/s | 27.6332 Ops/s | |
| test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-100000-10000-100-False] | 19.3083ms | 17.7675ms | 56.2827 Ops/s | 55.2427 Ops/s | |
| test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-1000000-10000-100-True] | 39.2331ms | 36.9674ms | 27.0509 Ops/s | 26.4309 Ops/s | |
| test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-1000000-10000-100-False] | 20.8027ms | 19.3577ms | 51.6590 Ops/s | 51.6852 Ops/s |
Collaborator
Author
|
Thanks @matteobettini I addressed your comment! |
Contributor
|
Amazing! Thanks for this <3 |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Labels
bug
Something isn't working
CLA Signed
This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Changes Made
1.
torchrl/modules/models/multiagent.pyFixed the
forwardmethod inMultiAgentNetBaseto properly respect theagent_dimparameter:agent_dimto a positive index based on input dimensionsout_dimsparameter2.
test/test_modules.pyAdded a non-regression test
test_multiagent_custom_agent_dim:MultiAgentNetBaseagent_dimvalues (1 and -3) with bothshare_params=Trueandshare_params=FalseThe test creates a custom
MultiAgentNetBasesubclass (similar to the user's reproduction case) with an MLP that processes inputs with shape(batch, n_agents, seq_len, obs_dim)where agents are at dimension 1, and verifies the output correctly maintains agents at dimension 1.