Fix issue 394 and update info #481
Conversation
Note: Other AI code review bot(s) detected. CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Summary by CodeRabbit
Walkthrough

Adds a comprehensive reinforcement-learning subsystem: new RL interfaces and bases, many agent implementations (value-based, policy-gradient, actor-critic, model-based, offline, multi-agent, bandits, tabular, planning, eligibility traces, Monte Carlo), replay buffers (uniform + prioritized), trajectory/context types, options for many algorithms, a CartPole environment, PredictionModelBuilder RL integration, RL README, unit tests, and compatibility/initializer fixes.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
autonumber
actor User
participant Builder as PredictionModelBuilder
participant Env as IEnvironment<T>
participant Agent as IRLAgent<T>
User->>Builder: ConfigureEnvironment(env)
User->>Builder: BuildAsync(episodes, verbose)
loop each episode
Builder->>Env: Reset()
Env-->>Builder: state
loop until done
Builder->>Agent: SelectAction(state, training=true)
Agent-->>Builder: action
Builder->>Env: Step(action)
Env-->>Builder: nextState, reward, done, info
Builder->>Agent: StoreExperience(state, action, reward, nextState, done)
Builder->>Agent: Train()
Agent-->>Builder: metrics
Builder->>Builder: log/accumulate metrics
Builder->>Builder: state = nextState
end
end
    Builder-->>User: PredictionModelResult (model/parameters, metrics)
```
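For readers who prefer code to diagrams, here is the same loop in sketch form. The interface shapes used here (Reset, Step, SelectAction, StoreExperience, Train) are assumptions inferred from the walkthrough, not the library's exact signatures.

```csharp
using System;

// Minimal sketch of the training loop shown in the sequence diagram.
// The interfaces below are simplified stand-ins, not the library's real API.
public interface ISketchEnvironment
{
    double[] Reset();
    (double[] NextState, double Reward, bool Done) Step(int action);
}

public interface ISketchAgent
{
    int SelectAction(double[] state, bool training);
    void StoreExperience(double[] state, int action, double reward, double[] nextState, bool done);
    void Train();
}

public static class RlTrainingLoopSketch
{
    public static void Run(ISketchEnvironment env, ISketchAgent agent, int episodes)
    {
        for (int episode = 0; episode < episodes; episode++)
        {
            var state = env.Reset();            // start of a new episode
            bool done = false;
            double episodeReward = 0.0;

            while (!done)
            {
                int action = agent.SelectAction(state, training: true);
                var (nextState, reward, isDone) = env.Step(action);

                agent.StoreExperience(state, action, reward, nextState, isDone);
                agent.Train();                  // one learning update per environment step

                episodeReward += reward;
                state = nextState;
                done = isDone;
            }

            Console.WriteLine($"Episode {episode}: total reward = {episodeReward}");
        }
    }
}
```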
Estimated code review effort: 🎯 5 (Critical) | ⏱️ ~120+ minutes
Pre-merge checks and finishing touches

❌ Failed checks (1 warning, 2 inconclusive)
Pull Request Overview
This pull request adds foundational reinforcement learning infrastructure to AiDotNet, including a base agent class, comprehensive architecture documentation, and unit tests for planned components.
Key Changes
- Introduces `ReinforcementLearningAgentBase<T>`, an abstract class providing common RL agent functionality
- Adds comprehensive architecture documentation outlining the planned RL module design with 20+ algorithms
- Creates unit tests for `UniformReplayBuffer`, `EpsilonGreedyPolicy`, and `CartPoleEnvironment` (implementations not yet added); a sketch of one such test follows below
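As a rough illustration of what those tests exercise, here is a self-contained sketch in the same spirit. The buffer and experience types below are stand-ins written for this example; the PR's generic `UniformReplayBuffer<T>` and `Experience` types will differ in detail.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using Xunit;

// Illustrative types only; not the PR's implementations.
public sealed record SketchExperience(double[] State, int Action, double Reward, double[] NextState, bool Done);

public sealed class TinyUniformReplayBuffer
{
    private readonly List<SketchExperience> _items = new();
    private readonly int _capacity;
    private readonly Random _rng = new(0);

    public TinyUniformReplayBuffer(int capacity) => _capacity = capacity;

    public void Add(SketchExperience e)
    {
        if (_items.Count == _capacity) _items.RemoveAt(0); // drop oldest when full
        _items.Add(e);
    }

    public bool CanSample(int batchSize) => _items.Count >= batchSize;

    public List<SketchExperience> Sample(int batchSize) =>
        Enumerable.Range(0, batchSize).Select(_ => _items[_rng.Next(_items.Count)]).ToList();

    public void Clear() => _items.Clear();
}

public class UniformReplayBufferSketchTests
{
    [Fact]
    public void CanSample_OnlyAfterEnoughExperiencesAreAdded()
    {
        var buffer = new TinyUniformReplayBuffer(capacity: 100);
        for (int i = 0; i < 10; i++)
            buffer.Add(new SketchExperience(new[] { (double)i }, 0, 1.0, new[] { (double)(i + 1) }, false));

        Assert.False(buffer.CanSample(32));          // not enough data yet
        Assert.True(buffer.CanSample(8));            // enough for a small batch
        Assert.Equal(8, buffer.Sample(8).Count);

        buffer.Clear();
        Assert.False(buffer.CanSample(1));
    }
}
```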
Reviewed Changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| src/ReinforcementLearning/Agents/ReinforcementLearningAgentBase.cs | Defines abstract base class for RL agents with common methods and a configuration options class |
| src/ReinforcementLearning/ARCHITECTURE.md | Comprehensive design document detailing RL integration patterns, algorithms, and implementation roadmap |
| tests/AiDotNet.Tests/UnitTests/ReinforcementLearning/UniformReplayBufferTests.cs | Unit tests for replay buffer functionality (implementation pending) |
| tests/AiDotNet.Tests/UnitTests/ReinforcementLearning/EpsilonGreedyPolicyTests.cs | Unit tests for epsilon-greedy exploration policy (implementation pending) |
| tests/AiDotNet.Tests/UnitTests/ReinforcementLearning/CartPoleEnvironmentTests.cs | Unit tests for CartPole environment (implementation pending) |
Actionable comments posted: 7
♻️ Duplicate comments (1)
src/ReinforcementLearning/Agents/ReinforcementLearningAgentBase.cs (1)
366-366: Fix syntax error: missing semicolon.

Missing semicolon after `init` is causing compilation failure (CS8180). Apply this diff:

```diff
- public double EpsilonEnd { get; init} = 0.01;
+ public double EpsilonEnd { get; init; } = 0.01;
```
🧹 Nitpick comments (3)
src/ReinforcementLearning/Agents/DQN/DQNOptions.cs (1)
134-153: Remove unused variable.

Line 136 declares `numOps` but never uses it. Consider removing this unused variable. Apply this diff:

```diff
 public static DQNOptions<T> Default(int stateSize, int actionSize, T learningRate, T discountFactor)
 {
-    var numOps = NumericOperations<T>.Instance;
     return new DQNOptions<T>
     {
```

src/ReinforcementLearning/Agents/PPO/PPOAgent.cs (1)
705-757: Consider extracting common helper methods.

The helper methods `ArgMax`, `Max`, `ComputeMean`, and `ComputeStd` are likely duplicated in other agent implementations (e.g., DQNAgent.cs). Consider extracting these to a shared utility class to reduce code duplication.
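One possible shape for such a shared utility, sketched against plain `double[]` rather than the library's `Vector<T>` (an assumption made to keep the example self-contained):

```csharp
using System;

// Illustrative shared helpers; a sketch, not the repository's actual utility class.
internal static class AgentMathUtils
{
    public static int ArgMax(double[] values)
    {
        int best = 0;
        for (int i = 1; i < values.Length; i++)
            if (values[i] > values[best]) best = i;
        return best;
    }

    public static double Max(double[] values) => values[ArgMax(values)];

    public static double ComputeMean(double[] values)
    {
        double sum = 0;
        foreach (var v in values) sum += v;
        return sum / values.Length;
    }

    public static double ComputeStd(double[] values)
    {
        double mean = ComputeMean(values);
        double sumSq = 0;
        foreach (var v in values) sumSq += (v - mean) * (v - mean);
        return Math.Sqrt(sumSq / values.Length);
    }
}
```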
src/ReinforcementLearning/Agents/DQN/DQNAgent.cs (1)

167-212: Consider batching parameter updates for efficiency.

The training loop updates network parameters after each individual experience (lines 197-211), resulting in `BatchSize` parameter updates per training step. Standard DQN implementations accumulate gradients across the batch and perform a single parameter update, which is more efficient and stable. Consider refactoring to accumulate gradients and update once per batch:

```diff
 foreach (var experience in batch)
 {
     // ... compute loss and gradients ...
     var gradients = LossFunction.ComputeGradient(currentQValues, targetQValues);
     _qNetwork.Backward(gradients);
-
-    // Update weights using learning rate
-    var parameters = _qNetwork.GetFlattenedParameters();
-    var gradientVector = _qNetwork.GetFlattenedGradients();
-
-    for (int i = 0; i < parameters.Length; i++)
-    {
-        var update = NumOps.Multiply(LearningRate, gradientVector[i]);
-        parameters[i] = NumOps.Subtract(parameters[i], update);
-    }
-
-    _qNetwork.UpdateParameters(parameters);
 }
+
+// Update weights once after processing entire batch
+var parameters = _qNetwork.GetFlattenedParameters();
+var gradientVector = _qNetwork.GetFlattenedGradients();
+
+for (int i = 0; i < parameters.Length; i++)
+{
+    var update = NumOps.Multiply(LearningRate, gradientVector[i]);
+    parameters[i] = NumOps.Subtract(parameters[i], update);
+}
+
+_qNetwork.UpdateParameters(parameters);
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (20)
- src/Enums/ModelType.cs (1 hunks)
- src/PredictionModelBuilder.cs (3 hunks)
- src/ReinforcementLearning/Agents/DQN/DQNAgent.cs (1 hunks)
- src/ReinforcementLearning/Agents/DQN/DQNOptions.cs (1 hunks)
- src/ReinforcementLearning/Agents/PPO/PPOAgent.cs (1 hunks)
- src/ReinforcementLearning/Agents/PPO/PPOOptions.cs (1 hunks)
- src/ReinforcementLearning/Agents/README.md (1 hunks)
- src/ReinforcementLearning/Agents/ReinforcementLearningAgentBase.cs (1 hunks)
- src/ReinforcementLearning/Agents/SAC/SACOptions.cs (1 hunks)
- src/ReinforcementLearning/Common/Trajectory.cs (1 hunks)
- src/ReinforcementLearning/Environments/CartPoleEnvironment.cs (1 hunks)
- src/ReinforcementLearning/INTEGRATION_PLAN.md (1 hunks)
- src/ReinforcementLearning/Interfaces/IEnvironment.cs (1 hunks)
- src/ReinforcementLearning/Interfaces/IRLAgent.cs (1 hunks)
- src/ReinforcementLearning/ReplayBuffers/Experience.cs (1 hunks)
- src/ReinforcementLearning/ReplayBuffers/IReplayBuffer.cs (1 hunks)
- src/ReinforcementLearning/ReplayBuffers/UniformReplayBuffer.cs (1 hunks)
- tests/AiDotNet.Tests/UnitTests/ReinforcementLearning/CartPoleEnvironmentTests.cs (1 hunks)
- tests/AiDotNet.Tests/UnitTests/ReinforcementLearning/EpsilonGreedyPolicyTests.cs (1 hunks)
- tests/AiDotNet.Tests/UnitTests/ReinforcementLearning/UniformReplayBufferTests.cs (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (16)
src/ReinforcementLearning/ReplayBuffers/IReplayBuffer.cs (2)
src/ReinforcementLearning/Agents/DQN/DQNAgent.cs (2)
T(150-228)T(410-423)src/ReinforcementLearning/ReplayBuffers/UniformReplayBuffer.cs (3)
Add(47-59)CanSample(85-88)Clear(91-95)
tests/AiDotNet.Tests/UnitTests/ReinforcementLearning/UniformReplayBufferTests.cs (1)
src/ReinforcementLearning/ReplayBuffers/UniformReplayBuffer.cs (5)
UniformReplayBuffer(18-96)UniformReplayBuffer(35-44)Add(47-59)CanSample(85-88)Clear(91-95)
src/ReinforcementLearning/ReplayBuffers/Experience.cs (3)
src/ReinforcementLearning/Agents/DQN/DQNAgent.cs (3)
T(150-228)T(410-423)Vector(121-140)src/ReinforcementLearning/Environments/CartPoleEnvironment.cs (3)
Vector(89-99)Vector(102-162)Vector(176-184)src/ReinforcementLearning/Interfaces/IEnvironment.cs (2)
Vector(58-58)Vector(78-78)
src/ReinforcementLearning/ReplayBuffers/UniformReplayBuffer.cs (1)
src/ReinforcementLearning/ReplayBuffers/IReplayBuffer.cs (4)
List(44-44)Add(37-37)CanSample(51-51)Clear(56-56)
src/ReinforcementLearning/Interfaces/IRLAgent.cs (2)
src/ReinforcementLearning/Agents/DQN/DQNAgent.cs (5)
T(150-228)T(410-423)IFullModel(328-351)Vector(121-140)StoreExperience(143-147)src/ReinforcementLearning/Agents/ReinforcementLearningAgentBase.cs (8)
T(134-134)T(283-294)IFullModel(238-238)Vector(118-118)Vector(149-152)StoreExperience(128-128)Train(167-172)ResetEpisode(139-142)
src/ReinforcementLearning/Interfaces/IEnvironment.cs (1)
src/ReinforcementLearning/Environments/CartPoleEnvironment.cs (5)
Vector(89-99)Vector(102-162)Vector(176-184)Seed(165-168)Close(171-174)
src/ReinforcementLearning/Agents/PPO/PPOOptions.cs (2)
src/ReinforcementLearning/Agents/PPO/PPOAgent.cs (11)
T(221-233)T(235-243)T(245-269)T(272-312)T(372-431)T(433-471)T(716-725)T(727-730)T(732-737)T(739-745)T(747-757)src/Helpers/NeuralNetworkHelper.cs (1)
ILossFunction(49-76)
tests/AiDotNet.Tests/UnitTests/ReinforcementLearning/CartPoleEnvironmentTests.cs (2)
src/ReinforcementLearning/Environments/CartPoleEnvironment.cs (4)
CartPoleEnvironment(36-185)CartPoleEnvironment(77-86)Seed(165-168)Close(171-174)src/ReinforcementLearning/Interfaces/IEnvironment.cs (2)
Seed(84-84)Close(89-89)
src/ReinforcementLearning/Environments/CartPoleEnvironment.cs (1)
src/ReinforcementLearning/Interfaces/IEnvironment.cs (4)
Vector(58-58)Vector(78-78)Seed(84-84)Close(89-89)
src/ReinforcementLearning/Agents/PPO/PPOAgent.cs (4)
src/ReinforcementLearning/Agents/PPO/PPOOptions.cs (2)
PPOOptions(29-192)PPOOptions(162-191)src/ReinforcementLearning/Common/Trajectory.cs (4)
Trajectory(27-114)Trajectory(72-80)AddStep(85-93)Clear(103-113)src/ReinforcementLearning/Agents/ReinforcementLearningAgentBase.cs (1)
ReinforcementLearningOptions(316-382)src/Helpers/NeuralNetworkHelper.cs (1)
ILossFunction(49-76)
src/ReinforcementLearning/Agents/DQN/DQNOptions.cs (2)
src/ReinforcementLearning/Agents/DQN/DQNAgent.cs (2)
T(150-228)T(410-423)src/Helpers/NeuralNetworkHelper.cs (1)
ILossFunction(49-76)
src/ReinforcementLearning/Agents/ReinforcementLearningAgentBase.cs (6)
src/PredictionModelBuilder.cs (1)
T(717-728)src/ReinforcementLearning/Agents/DQN/DQNAgent.cs (13)
T(150-228)T(410-423)Vector(121-140)StoreExperience(143-147)ModelMetadata(241-250)Serialize(253-277)Deserialize(280-302)Matrix(305-314)Matrix(354-368)SetParameters(317-325)Dictionary(231-238)IFullModel(328-351)ApplyGradients(371-383)src/ReinforcementLearning/Interfaces/IRLAgent.cs (5)
T(65-65)Vector(41-41)StoreExperience(55-55)ResetEpisode(76-76)Dictionary(71-71)src/Helpers/MathHelper.cs (1)
INumericOperations(36-66)src/ReinforcementLearning/ReplayBuffers/UniformReplayBuffer.cs (2)
List(62-82)Add(47-59)src/Helpers/NeuralNetworkHelper.cs (1)
ILossFunction(49-76)
src/ReinforcementLearning/Agents/SAC/SACOptions.cs (1)
src/Helpers/NeuralNetworkHelper.cs (1)
ILossFunction(49-76)
src/PredictionModelBuilder.cs (7)
src/ReinforcementLearning/Agents/DQN/DQNAgent.cs (3)
T(150-228)T(410-423)StoreExperience(143-147)src/ReinforcementLearning/Agents/ReinforcementLearningAgentBase.cs (5)
T(134-134)T(283-294)ResetEpisode(139-142)StoreExperience(128-128)Train(167-172)src/ReinforcementLearning/Interfaces/IRLAgent.cs (3)
T(65-65)ResetEpisode(76-76)StoreExperience(55-55)src/Interfaces/IPredictionModelBuilder.cs (21)
Task(642-642)Task(678-678)PredictionModelResult(236-236)PredictionModelResult(278-278)TOutput(208-208)IPredictionModelBuilder(41-41)IPredictionModelBuilder(57-57)IPredictionModelBuilder(72-72)IPredictionModelBuilder(87-87)IPredictionModelBuilder(105-105)IPredictionModelBuilder(139-139)IPredictionModelBuilder(157-157)IPredictionModelBuilder(176-176)IPredictionModelBuilder(192-192)IPredictionModelBuilder(294-294)IPredictionModelBuilder(311-311)IPredictionModelBuilder(340-340)IPredictionModelBuilder(366-370)IPredictionModelBuilder(400-400)IPredictionModelBuilder(416-416)IPredictionModelBuilder(521-524)src/Models/Results/PredictionModelResult.cs (9)
PredictionModelResult(350-358)PredictionModelResult(375-402)PredictionModelResult(443-471)PredictionModelResult(498-500)PredictionModelResult(1314-1337)TOutput(565-581)Train(676-682)IEnumerable(980-988)IEnumerable(1427-1460)src/ReinforcementLearning/ReplayBuffers/IReplayBuffer.cs (2)
List(44-44)Add(37-37)src/ReinforcementLearning/ReplayBuffers/UniformReplayBuffer.cs (2)
List(62-82)Add(47-59)
src/ReinforcementLearning/Common/Trajectory.cs (1)
src/ReinforcementLearning/Agents/PPO/PPOAgent.cs (15)
T(221-233)T(235-243)T(245-269)T(272-312)T(372-431)T(433-471)T(716-725)T(727-730)T(732-737)T(739-745)T(747-757)Vector(142-156)Vector(158-179)Vector(181-206)Vector(661-680)
src/ReinforcementLearning/Agents/DQN/DQNAgent.cs (5)
src/ReinforcementLearning/Agents/ReinforcementLearningAgentBase.cs (19)
T(134-134)T(283-294)ReinforcementLearningAgentBase(34-310)ReinforcementLearningAgentBase(95-108)ReinforcementLearningOptions(316-382)Vector(118-118)Vector(149-152)StoreExperience(128-128)Train(167-172)Dictionary(225-233)Dictionary(269-278)ModelMetadata(162-162)Serialize(177-177)Deserialize(182-182)Matrix(187-187)Matrix(243-246)SetParameters(192-192)IFullModel(238-238)ApplyGradients(251-251)src/ReinforcementLearning/Interfaces/IRLAgent.cs (4)
T(65-65)Vector(41-41)StoreExperience(55-55)Dictionary(71-71)src/ReinforcementLearning/Agents/DQN/DQNOptions.cs (2)
DQNOptions(25-154)DQNOptions(134-153)src/ReinforcementLearning/ReplayBuffers/UniformReplayBuffer.cs (5)
UniformReplayBuffer(18-96)UniformReplayBuffer(35-44)Add(47-59)List(62-82)CanSample(85-88)src/ReinforcementLearning/ReplayBuffers/IReplayBuffer.cs (3)
Add(37-37)List(44-44)CanSample(51-51)
🪛 GitHub Actions: Build
src/ReinforcementLearning/Agents/ReinforcementLearningAgentBase.cs
[error] 366-366: CS8180: { or ; or => expected
🪛 GitHub Actions: Quality Gates (.NET)
src/ReinforcementLearning/Agents/ReinforcementLearningAgentBase.cs
[error] 366-366: CS8180: { or ; or => expected
🪛 GitHub Check: Build All Frameworks
src/ReinforcementLearning/Agents/ReinforcementLearningAgentBase.cs
[failure] 366-366:
{ or ; or => expected
[failure] 366-366:
{ or ; or => expected
[failure] 366-366:
{ or ; or => expected
[failure] 366-366:
{ or ; or => expected
🪛 GitHub Check: Publish Size Analysis
src/ReinforcementLearning/Agents/ReinforcementLearningAgentBase.cs
[failure] 366-366:
{ or ; or => expected
Actionable comments posted: 3
♻️ Duplicate comments (2)
src/ReinforcementLearning/Agents/PPO/PPOAgent.cs (2)
477-511: Policy update still ignores PPO's clipped loss gradient.

`UpdatePolicyNetwork` continues to backprop `gradOutput = action * advantage`, which is the same oversimplified gradient the earlier review flagged; there is no dependence on the clipped surrogate you compute in `UpdateNetworks`. The policy will not optimize PPO's objective. Please backprop the derivative of the clipped surrogate loss (or feed the loss into the network's optimizer) instead of this placeholder gradient.
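For reference, a standalone sketch of the clipped-surrogate loss and its per-sample gradient with respect to the new log-probability, in plain doubles rather than the library's `NumOps` API; the returned gradient would still need to be chained through log π_θ(a|s) into the policy network's backward pass:

```csharp
using System;

// Sketch of PPO's clipped surrogate for a single sample. Illustrative only.
public static class PpoClippedSurrogateSketch
{
    public static (double Loss, double GradLogProbNew) Compute(
        double logProbNew, double logProbOld, double advantage, double clipEpsilon = 0.2)
    {
        double ratio = Math.Exp(logProbNew - logProbOld);                       // r_t(theta)
        double clipped = Math.Max(1.0 - clipEpsilon, Math.Min(1.0 + clipEpsilon, ratio));

        double unclippedTerm = ratio * advantage;
        double clippedTerm = clipped * advantage;

        // PPO maximizes min(unclipped, clipped); as a loss we minimize the negative.
        double loss = -Math.Min(unclippedTerm, clippedTerm);

        // d(ratio)/d(logProbNew) = ratio, so the unclipped branch contributes -ratio * advantage.
        // When the clipped branch is the active minimum, the clip is binding and the gradient is zero.
        bool unclippedActive = unclippedTerm <= clippedTerm;
        double grad = unclippedActive ? -ratio * advantage : 0.0;

        return (loss, grad);
    }
}
```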
323-333: Fix GAE bootstrap for terminal steps.

When `_trajectory.Dones[t]` is true we must not bootstrap from `_trajectory.Values[t + 1]`, otherwise we leak value estimates from the next episode into the tail of a finished episode. That breaks the TD error and skews the advantages/returns. Please zero `nextValue` whenever the current step is terminal, even in the non-final branch. A minimal fix is:

```diff
- if (t == _trajectory.Length - 1)
- {
-     nextValue = _trajectory.Dones[t] ? NumOps.Zero : _trajectory.Values[t];
- }
- else
- {
-     nextValue = _trajectory.Values[t + 1];
- }
+ if (t == _trajectory.Length - 1)
+ {
+     nextValue = NumOps.Zero;
+ }
+ else
+ {
+     nextValue = _trajectory.Dones[t] ? NumOps.Zero : _trajectory.Values[t + 1];
+ }
```
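A self-contained sketch of GAE with the terminal handling this fix asks for (plain doubles, illustrative only):

```csharp
// Generalized Advantage Estimation with terminal masking: terminal steps never
// bootstrap, so no value estimate leaks across episode boundaries.
public static class GaeSketch
{
    public static double[] ComputeAdvantages(
        double[] rewards, double[] values, bool[] dones, double gamma = 0.99, double lambda = 0.95)
    {
        int n = rewards.Length;
        var advantages = new double[n];
        double gae = 0.0;

        for (int t = n - 1; t >= 0; t--)
        {
            // The final step of the rollout and any terminal step bootstrap from zero.
            double nextValue = (t == n - 1 || dones[t]) ? 0.0 : values[t + 1];
            double nextNonTerminal = dones[t] ? 0.0 : 1.0;

            double delta = rewards[t] + gamma * nextValue - values[t];   // TD error
            gae = delta + gamma * lambda * nextNonTerminal * gae;        // reset at episode ends
            advantages[t] = gae;
        }

        return advantages;
    }
}
```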
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
- src/ReinforcementLearning/Agents/A2C/A2CAgent.cs (1 hunks)
- src/ReinforcementLearning/Agents/A2C/A2COptions.cs (1 hunks)
- src/ReinforcementLearning/Agents/PPO/PPOAgent.cs (1 hunks)
- src/ReinforcementLearning/Agents/REINFORCE/REINFORCEAgent.cs (1 hunks)
- src/ReinforcementLearning/Agents/REINFORCE/REINFORCEOptions.cs (1 hunks)
- src/ReinforcementLearning/Agents/SAC/SACAgent.cs (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (5)
src/ReinforcementLearning/Agents/A2C/A2COptions.cs (2)
src/ReinforcementLearning/Agents/A2C/A2CAgent.cs (3)
T(169-200)T(203-265)T(340-371)src/Helpers/NeuralNetworkHelper.cs (1)
ILossFunction(49-76)
src/ReinforcementLearning/Agents/PPO/PPOAgent.cs (7)
src/ReinforcementLearning/Agents/A2C/A2CAgent.cs (15)
T(169-200)T(203-265)T(340-371)NeuralNetwork(66-87)NeuralNetwork(89-109)Vector(112-124)Vector(126-134)Vector(136-159)Vector(486-511)SampleCategorical(513-525)ArgMax(527-536)StoreExperience(162-167)ComputeAdvantages(267-310)UpdatePolicyNetwork(312-324)UpdateValueNetwork(326-338)src/ReinforcementLearning/Agents/REINFORCE/REINFORCEAgent.cs (11)
T(153-187)T(190-237)NeuralNetwork(69-93)Vector(96-108)Vector(110-118)Vector(120-143)Vector(391-416)SampleCategorical(418-430)ArgMax(432-441)StoreExperience(146-151)UpdatePolicyNetwork(280-292)src/ReinforcementLearning/Agents/PPO/PPOOptions.cs (2)
PPOOptions(29-192)PPOOptions(162-191)src/ReinforcementLearning/Common/Trajectory.cs (4)
Trajectory(27-114)Trajectory(72-80)AddStep(85-93)Clear(103-113)src/ReinforcementLearning/Agents/ReinforcementLearningAgentBase.cs (1)
ReinforcementLearningOptions(316-382)src/Helpers/MathHelper.cs (1)
MathHelper(19-992)src/Helpers/StatisticsHelper.cs (1)
StatisticsHelper(17-6685)
src/ReinforcementLearning/Agents/SAC/SACAgent.cs (5)
src/ReinforcementLearning/Agents/A2C/A2CAgent.cs (12)
T(169-200)T(203-265)T(340-371)Vector(112-124)Vector(126-134)Vector(136-159)Vector(486-511)Serialize(394-411)Deserialize(414-429)Matrix(432-445)Matrix(473-477)SetParameters(448-462)src/ReinforcementLearning/Agents/PPO/PPOAgent.cs (16)
T(222-234)T(236-244)T(246-270)T(273-313)T(376-435)T(437-475)T(712-721)Vector(143-157)Vector(159-180)Vector(182-207)Vector(665-684)Serialize(564-582)Deserialize(585-601)Matrix(604-619)Matrix(649-656)SetParameters(622-638)src/ReinforcementLearning/Agents/REINFORCE/REINFORCEAgent.cs (11)
T(153-187)T(190-237)Vector(96-108)Vector(110-118)Vector(120-143)Vector(391-416)Serialize(315-328)Deserialize(331-342)Matrix(345-356)Matrix(378-382)SetParameters(359-367)src/ReinforcementLearning/Agents/ReinforcementLearningAgentBase.cs (1)
ReinforcementLearningOptions(316-382)src/Helpers/MathHelper.cs (1)
MathHelper(19-992)
src/ReinforcementLearning/Agents/REINFORCE/REINFORCEAgent.cs (8)
src/ReinforcementLearning/Agents/A2C/A2CAgent.cs (13)
T(169-200)T(203-265)T(340-371)NeuralNetwork(66-87)NeuralNetwork(89-109)Vector(112-124)Vector(126-134)Vector(136-159)Vector(486-511)SampleCategorical(513-525)ArgMax(527-536)StoreExperience(162-167)UpdatePolicyNetwork(312-324)src/ReinforcementLearning/Agents/PPO/PPOAgent.cs (17)
T(222-234)T(236-244)T(246-270)T(273-313)T(376-435)T(437-475)T(712-721)NeuralNetwork(81-114)NeuralNetwork(116-140)Vector(143-157)Vector(159-180)Vector(182-207)Vector(665-684)SampleCategorical(686-699)ArgMax(701-710)StoreExperience(210-220)UpdatePolicyNetwork(477-511)src/ReinforcementLearning/Agents/SAC/SACAgent.cs (9)
T(213-252)T(254-318)T(320-358)NeuralNetwork(97-121)NeuralNetwork(123-148)Vector(151-156)Vector(158-204)Vector(426-434)StoreExperience(207-210)src/ReinforcementLearning/Agents/REINFORCE/REINFORCEOptions.cs (2)
REINFORCEOptions(25-47)REINFORCEOptions(35-46)src/ReinforcementLearning/Agents/ReinforcementLearningAgentBase.cs (1)
ReinforcementLearningOptions(316-382)src/Helpers/MathHelper.cs (1)
MathHelper(19-992)src/ReinforcementLearning/Common/Trajectory.cs (1)
AddStep(85-93)src/Helpers/StatisticsHelper.cs (1)
StatisticsHelper(17-6685)
src/ReinforcementLearning/Agents/REINFORCE/REINFORCEOptions.cs (1)
src/ReinforcementLearning/Agents/REINFORCE/REINFORCEAgent.cs (2)
T(153-187)T(190-237)
Actionable comments posted: 14
♻️ Duplicate comments (1)
src/ReinforcementLearning/Agents/DoubleDQN/DoubleDQNOptions.cs (1)
11-18: Same validation gaps as DQNOptions.

This class has the same missing validation for required properties (StateSize, ActionSize, LearningRate, DiscountFactor) and epsilon parameters as noted in `DQNOptions.cs`. If you implement validation or use the `required` keyword there, apply it here as well for consistency.
🧹 Nitpick comments (7)
src/ReinforcementLearning/Agents/TD3/TD3Options.cs (2)
11-16: Consider immutability and validation for core configuration properties.

These core properties use `set` accessors, allowing mutation after initialization. For configuration objects, `init` accessors are typically preferred to ensure immutability after construction. Additionally, these properties have no defaults and no validation:

- `StateSize` and `ActionSize` should be positive
- Learning rates should be positive
- `DiscountFactor` typically ranges from 0 to 1
- `TargetUpdateTau` typically ranges from 0 to 1

Consider this refactor:

```diff
- public int StateSize { get; set; }
- public int ActionSize { get; set; }
- public T ActorLearningRate { get; set; }
- public T CriticLearningRate { get; set; }
- public T DiscountFactor { get; set; }
- public T TargetUpdateTau { get; set; }
+ public required int StateSize { get; init; }
+ public required int ActionSize { get; init; }
+ public required T ActorLearningRate { get; init; }
+ public required T CriticLearningRate { get; init; }
+ public required T DiscountFactor { get; init; }
+ public required T TargetUpdateTau { get; init; }
```

Then add validation in a constructor or validation method to ensure values are within valid ranges.
22-24: Consider type consistency for noise parameters.

The noise parameters use the `double` type, while other hyperparameters like learning rates use the generic type `T`. While this may be intentional (noise is applied in action space), it creates type inconsistency within the configuration class. For better consistency and flexibility, consider using the generic type `T` for these parameters as well, especially since TD3 can operate on different numeric types.

```diff
- public double ExplorationNoise { get; set; } = 0.1;
- public double TargetPolicyNoise { get; set; } = 0.2;
- public double TargetNoiseClip { get; set; } = 0.5;
+ public T ExplorationNoise { get; set; } = T.CreateChecked(0.1);
+ public T TargetPolicyNoise { get; set; } = T.CreateChecked(0.2);
+ public T TargetNoiseClip { get; set; } = T.CreateChecked(0.5);
```

src/ReinforcementLearning/Agents/DQN/DQNOptions.cs (3)
11-14: Consider adding validation or using required properties.

The core configuration properties lack defaults and validation. Invalid values (e.g., `StateSize <= 0`, negative `LearningRate`) could cause runtime failures in the agent. Consider either:

- Adding init-time validation in a constructor/validator method
- Using the C# 11 `required` keyword to enforce explicit initialization
- Documenting constraints via XML comments

Example with `required`:

```diff
- public int StateSize { get; set; }
- public int ActionSize { get; set; }
- public T LearningRate { get; set; }
- public T DiscountFactor { get; set; }
+ /// <summary>State space dimension. Must be positive.</summary>
+ public required int StateSize { get; set; }
+
+ /// <summary>Action space dimension. Must be positive.</summary>
+ public required int ActionSize { get; set; }
+
+ /// <summary>Learning rate for gradient updates. Must be positive.</summary>
+ public required T LearningRate { get; set; }
+
+ /// <summary>Discount factor (gamma) for future rewards. Typically in [0, 1].</summary>
+ public required T DiscountFactor { get; set; }
```
15-17: Add validation for epsilon exploration parameters.

The epsilon parameters lack validation. Invalid values (e.g., `EpsilonStart < EpsilonEnd`, `EpsilonDecay > 1.0`) could break exploration behavior. Add validation or document constraints:

```diff
+ /// <summary>Initial exploration rate. Must be in [0, 1] and >= EpsilonEnd.</summary>
  public double EpsilonStart { get; set; } = 1.0;
+
+ /// <summary>Final exploration rate. Must be in [0, 1] and <= EpsilonStart.</summary>
  public double EpsilonEnd { get; set; } = 0.01;
+
+ /// <summary>Exploration decay rate per step. Must be in (0, 1].</summary>
  public double EpsilonDecay { get; set; } = 0.995;
```
18-20: Inconsistent defaults between DQN variants.

`DQNOptions` uses `BatchSize = 64` and `ReplayBufferSize = 100000`, while `DoubleDQNOptions` uses `BatchSize = 32` and `ReplayBufferSize = 10000`. This inconsistency may confuse users choosing between agents. Consider aligning defaults across variants unless there's a specific algorithmic reason for the differences. Document the rationale if intentional.
src/ReinforcementLearning/Agents/DoubleDQN/DoubleDQNOptions.cs (1)
9-25: Consider extracting a base options class to reduce duplication.

`DoubleDQNOptions<T>` and `DQNOptions<T>` share nearly identical structure with only default value differences. This duplication increases maintenance burden. Consider introducing a shared base class:

```csharp
public abstract class DQNOptionsBase<T>
{
    public required int StateSize { get; set; }
    public required int ActionSize { get; set; }
    public required T LearningRate { get; set; }
    public required T DiscountFactor { get; set; }
    public ILossFunction<T> LossFunction { get; set; } = new MeanSquaredError<T>();
    public double EpsilonStart { get; set; } = 1.0;
    public double EpsilonEnd { get; set; } = 0.01;
    public double EpsilonDecay { get; set; } = 0.995;
    public abstract int BatchSize { get; set; }
    public abstract int ReplayBufferSize { get; set; }
    public int TargetUpdateFrequency { get; set; } = 1000;
    public int WarmupSteps { get; set; } = 1000;
    public abstract List<int> HiddenLayers { get; set; }
    public int? Seed { get; set; }
}

public class DoubleDQNOptions<T> : DQNOptionsBase<T>
{
    public override int BatchSize { get; set; } = 32;
    public override int ReplayBufferSize { get; set; } = 10000;
    public override List<int> HiddenLayers { get; set; } = [64, 64];
}
```

This reduces duplication and makes it easier to maintain consistent behavior across variants.
src/ReinforcementLearning/Agents/DuelingDQN/DuelingDQNAgent.cs (1)
241-246: Deserialized configuration values are discarded.

Lines 241-244 read `stateSize`, `actionSize`, `learningRate`, and `discountFactor` but don't validate or use them. This means deserialization doesn't verify compatibility with the current agent's configuration. Consider validating these values match the agent's current options:

```diff
- reader.ReadInt32(); // stateSize
- reader.ReadInt32(); // actionSize
- reader.ReadDouble(); // learningRate
- reader.ReadDouble(); // discountFactor
+ var stateSize = reader.ReadInt32();
+ var actionSize = reader.ReadInt32();
+ var learningRate = reader.ReadDouble();
+ var discountFactor = reader.ReadDouble();
+
+ if (stateSize != _options.StateSize || actionSize != _options.ActionSize)
+     throw new InvalidOperationException("Serialized model dimensions don't match current options");
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (9)
- src/Enums/ModelType.cs (2 hunks)
- src/ReinforcementLearning/Agents/DDPG/DDPGAgent.cs (1 hunks)
- src/ReinforcementLearning/Agents/DDPG/DDPGOptions.cs (1 hunks)
- src/ReinforcementLearning/Agents/DQN/DQNOptions.cs (1 hunks)
- src/ReinforcementLearning/Agents/DoubleDQN/DoubleDQNAgent.cs (1 hunks)
- src/ReinforcementLearning/Agents/DoubleDQN/DoubleDQNOptions.cs (1 hunks)
- src/ReinforcementLearning/Agents/DuelingDQN/DuelingDQNAgent.cs (1 hunks)
- src/ReinforcementLearning/Agents/DuelingDQN/DuelingDQNOptions.cs (1 hunks)
- src/ReinforcementLearning/Agents/TD3/TD3Options.cs (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (8)
src/ReinforcementLearning/Agents/DoubleDQN/DoubleDQNOptions.cs (2)
src/ReinforcementLearning/Agents/DoubleDQN/DoubleDQNAgent.cs (1)
T(136-206)src/Helpers/NeuralNetworkHelper.cs (1)
ILossFunction(49-76)
src/ReinforcementLearning/Agents/DuelingDQN/DuelingDQNOptions.cs (1)
src/Helpers/NeuralNetworkHelper.cs (1)
ILossFunction(49-76)
src/ReinforcementLearning/Agents/DDPG/DDPGAgent.cs (5)
src/ReinforcementLearning/Agents/DuelingDQN/DuelingDQNAgent.cs (15)
T(124-187)CopyNetworkWeights(321-325)Vector(99-115)Vector(403-447)Backward(449-472)ModelMetadata(200-209)Serialize(212-233)Serialize(521-527)Deserialize(236-255)Deserialize(529-532)Matrix(258-261)Matrix(298-312)Matrix(491-514)SetParameters(264-267)ApplyGradients(315-318)src/ReinforcementLearning/Agents/DDPG/DDPGOptions.cs (1)
DDPGOptions(9-25)src/ReinforcementLearning/Agents/ReinforcementLearningAgentBase.cs (1)
ReinforcementLearningOptions(316-382)src/Helpers/MathHelper.cs (2)
MathHelper(19-992)INumericOperations(36-66)src/Helpers/NeuralNetworkHelper.cs (1)
ILossFunction(49-76)
src/ReinforcementLearning/Agents/DoubleDQN/DoubleDQNAgent.cs (2)
src/ReinforcementLearning/Agents/DoubleDQN/DoubleDQNOptions.cs (1)
DoubleDQNOptions(9-25)src/ReinforcementLearning/Agents/ReinforcementLearningAgentBase.cs (1)
ReinforcementLearningOptions(316-382)
src/ReinforcementLearning/Agents/DuelingDQN/DuelingDQNAgent.cs (4)
src/ReinforcementLearning/Agents/DDPG/DDPGAgent.cs (17)
T(172-197)T(199-239)T(241-271)CopyNetworkWeights(446-449)Vector(144-163)Vector(312-320)Vector(474-491)StoreExperience(166-169)Dictionary(323-329)ModelMetadata(332-341)Serialize(344-365)Deserialize(368-387)Matrix(390-403)Matrix(434-438)SetParameters(406-423)IFullModel(426-431)ApplyGradients(441-444)src/ReinforcementLearning/Agents/DuelingDQN/DuelingDQNOptions.cs (1)
DuelingDQNOptions(9-27)src/ReinforcementLearning/Agents/ReinforcementLearningAgentBase.cs (1)
ReinforcementLearningOptions(316-382)src/Helpers/MathHelper.cs (1)
INumericOperations(36-66)
src/ReinforcementLearning/Agents/DQN/DQNOptions.cs (6)
src/ReinforcementLearning/Agents/DQN/DQNAgent.cs (2)
T(150-228)T(410-423)src/Helpers/NeuralNetworkHelper.cs (1)
ILossFunction(49-76)src/ReinforcementLearning/ReplayBuffers/IReplayBuffer.cs (1)
List(44-44)src/ReinforcementLearning/ReplayBuffers/UniformReplayBuffer.cs (1)
List(62-82)src/ReinforcementLearning/Environments/CartPoleEnvironment.cs (1)
Seed(165-168)src/ReinforcementLearning/Interfaces/IEnvironment.cs (1)
Seed(84-84)
src/ReinforcementLearning/Agents/TD3/TD3Options.cs (1)
src/Helpers/NeuralNetworkHelper.cs (1)
ILossFunction(49-76)
src/ReinforcementLearning/Agents/DDPG/DDPGOptions.cs (2)
src/ReinforcementLearning/Agents/DDPG/DDPGAgent.cs (3)
T(172-197)T(199-239)T(241-271)src/Helpers/NeuralNetworkHelper.cs (1)
ILossFunction(49-76)
🔇 Additional comments (7)
src/ReinforcementLearning/Agents/TD3/TD3Options.cs (2)
18-21: LGTM: Standard TD3 hyperparameter defaults.

The default values align well with standard TD3 configurations:
- Batch size of 256 is common for off-policy RL
- Large replay buffer (1M) is appropriate for sample efficiency
- 25K warmup steps allow sufficient exploration before training
- Policy update frequency of 2 implements TD3's signature "delayed" policy updates
25-27: LGTM: Clean architecture and seeding configuration.

The hidden layer configurations use modern collection expression syntax and provide sensible defaults (two layers of 256 units each) for both actor and critic networks. The nullable `Seed` property appropriately allows optional reproducibility control.

src/ReinforcementLearning/Agents/DuelingDQN/DuelingDQNAgent.cs (5)
56-96: Constructor initialization looks correct.

The constructor properly initializes the dueling architecture with shared, value, and advantage streams, and correctly synchronizes the target network weights.
99-115: Epsilon-greedy action selection implemented correctly.

The implementation properly balances exploration (random actions) and exploitation (greedy Q-value maximization) with appropriate one-hot encoding.
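For context, the epsilon-greedy rule being described, as a standalone sketch (not the agent's actual code):

```csharp
using System;

// With probability epsilon take a random action; otherwise take the greedy arg-max over Q-values.
public static class EpsilonGreedySketch
{
    public static int SelectAction(double[] qValues, double epsilon, Random rng)
    {
        if (rng.NextDouble() < epsilon)
            return rng.Next(qValues.Length);       // explore

        int best = 0;                              // exploit: greedy action
        for (int a = 1; a < qValues.Length; a++)
            if (qValues[a] > qValues[best]) best = a;
        return best;
    }
}
```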
124-187: Train method implements Double DQN with dueling architecture correctly.

The training loop properly implements:

- Warmup period checks
- Double DQN target computation (online network selects, target evaluates; sketched below)
- Loss computation and gradient backpropagation
- Periodic target network synchronization
- Epsilon decay

The implementation depends on `DuelingNetwork.Backward` and `UpdateWeights` being correctly implemented (see later comments on those methods).
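The Double DQN target mentioned in the list above, written out as a standalone sketch (plain doubles; the agent's real implementation works on the library's tensor types):

```csharp
// Double DQN target: the online network chooses the next action, the target network evaluates it.
public static class DoubleDqnTargetSketch
{
    public static double ComputeTarget(
        double reward, bool done, double[] nextQOnline, double[] nextQTarget, double gamma)
    {
        if (done) return reward;                         // no bootstrap on terminal steps

        int bestAction = 0;                              // selection: online network
        for (int a = 1; a < nextQOnline.Length; a++)
            if (nextQOnline[a] > nextQOnline[bestAction]) bestAction = a;

        return reward + gamma * nextQTarget[bestAction]; // evaluation: target network
    }
}
```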
327-342: ArgMax implementation is correct.

The utility method properly finds the maximum Q-value index for action selection.
403-447: Forward pass implements dueling architecture correctly.

The dueling network properly:

- Processes state through shared feature layers
- Splits into separate value and advantage streams
- Applies mean advantage centering for identifiability: Q(s,a) = V(s) + (A(s,a) - mean(A))

This matches the Dueling DQN paper (Wang et al., 2016).
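The aggregation step in isolation, as a sketch (plain doubles, illustrative only):

```csharp
// Dueling aggregation: combine the state value V(s) with mean-centred advantages A(s,a).
public static class DuelingAggregationSketch
{
    public static double[] CombineStreams(double value, double[] advantages)
    {
        double meanAdvantage = 0.0;
        foreach (var a in advantages) meanAdvantage += a;
        meanAdvantage /= advantages.Length;

        var q = new double[advantages.Length];
        for (int a = 0; a < advantages.Length; a++)
            q[a] = value + (advantages[a] - meanAdvantage);  // Q(s,a) = V(s) + (A(s,a) - mean(A))
        return q;
    }
}
```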
Actionable comments posted: 24
♻️ Duplicate comments (19)
src/ReinforcementLearning/Agents/DDPG/DDPGAgent.cs (5)
59-90: The null-handling issue flagged in previous reviews remains unresolved.

Lines 60-69 dereference `options` to construct the base class arguments before the null guard on line 71 executes. A null caller will encounter `NullReferenceException` instead of `ArgumentNullException`. The previous review provided a fix using a helper method to validate `options` before the base constructor call.
242-272: The policy gradient implementation issue flagged in previous reviews remains unresolved.

Lines 258-263 compute actor gradients by multiplying the scalar Q-value by a hard-coded constant (`-0.01`), completely bypassing the critic's gradient ∂Q/∂a. Without backpropagating through the critic to obtain the true action gradient, the actor receives no meaningful learning signal and DDPG cannot function. The previous review correctly identified this as a critical algorithmic flaw requiring proper gradient computation from the critic network.
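To make the missing signal concrete, here is an illustrative sketch of obtaining ∂Q/∂a. It approximates the action gradient with central finite differences purely for clarity; a real fix would backpropagate through the critic network rather than use finite differences or a hard-coded constant.

```csharp
using System;

// DDPG actor gradient = dQ/da chained into da/dtheta. The critic must be differentiated
// with respect to the action; this sketch approximates that derivative numerically.
public static class DdpgActionGradientSketch
{
    // critic: (state, action) -> Q value
    public static double[] ApproximateActionGradient(
        Func<double[], double[], double> critic, double[] state, double[] action, double eps = 1e-4)
    {
        var grad = new double[action.Length];
        for (int i = 0; i < action.Length; i++)
        {
            var plus = (double[])action.Clone();
            var minus = (double[])action.Clone();
            plus[i] += eps;
            minus[i] -= eps;
            grad[i] = (critic(state, plus) - critic(state, minus)) / (2 * eps); // dQ/da_i
        }

        // The actor ascends Q, so -dQ/da is the loss gradient fed into its backward pass
        // (gradient ascent on Q is gradient descent on -Q).
        for (int i = 0; i < grad.Length; i++) grad[i] = -grad[i];
        return grad;
    }
}
```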
434-439: The `ComputeGradients` placeholder issue flagged in previous reviews remains unresolved.

Line 438 returns model weights via `GetParameters()` where gradients are expected. Any caller treating these weights as gradients will perform nonsensical updates. The previous review recommended either implementing proper gradient computation or explicitly throwing `NotSupportedException` to prevent misuse.
441-445: The `ApplyGradients` empty implementation issue flagged in previous reviews remains unresolved.

Lines 442-445 leave the method body empty with only a comment. This silently ignores gradient updates from any external caller, creating a false assumption of success. The previous review correctly recommended either throwing `NotSupportedException` or implementing the update logic.
475-492: The Ornstein-Uhlenbeck diffusion term issue flagged in previous reviews remains unresolved.

Lines 481-485 subtract the Gaussian noise term, producing `dx = -θx - σN(0,1)`, which is incorrect for the OU process. The stochastic diffusion term must be added, not subtracted. The previous review provided a fix separating drift and diffusion with correct addition. Additionally, line 487 uses `NumOps.Add` instead of `_numOps.Add`, creating an inconsistent field reference. Apply this additional fix for the field reference:

```diff
- _state[i] = NumOps.Add(_state[i], dx);
+ _state[i] = _numOps.Add(_state[i], dx);
```
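For reference, a minimal standalone Ornstein-Uhlenbeck process with the diffusion term added, using one common discretisation (the dt and sqrt(dt) factors are an assumption of this sketch, not necessarily how the agent parameterises its noise):

```csharp
using System;

// dx = theta * (mu - x) * dt + sigma * sqrt(dt) * N(0,1); the noise term is added, never subtracted.
public sealed class OrnsteinUhlenbeckNoiseSketch
{
    private readonly double _theta;
    private readonly double _mu;
    private readonly double _sigma;
    private readonly double _dt;
    private readonly Random _rng;
    private readonly double[] _state;

    public OrnsteinUhlenbeckNoiseSketch(int dim, double theta = 0.15, double mu = 0.0,
                                        double sigma = 0.2, double dt = 1e-2, int seed = 0)
    {
        _theta = theta;
        _mu = mu;
        _sigma = sigma;
        _dt = dt;
        _rng = new Random(seed);
        _state = new double[dim];
    }

    public double[] Sample()
    {
        for (int i = 0; i < _state.Length; i++)
        {
            double drift = _theta * (_mu - _state[i]) * _dt;
            double diffusion = _sigma * Math.Sqrt(_dt) * NextGaussian();
            _state[i] += drift + diffusion;   // add the stochastic term
        }
        return _state;
    }

    private double NextGaussian()
    {
        // Box-Muller transform for a standard normal sample.
        double u1 = 1.0 - _rng.NextDouble();
        double u2 = _rng.NextDouble();
        return Math.Sqrt(-2.0 * Math.Log(u1)) * Math.Cos(2.0 * Math.PI * u2);
    }
}
```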
src/ReinforcementLearning/Agents/SAC/SACAgent.cs (1)

321-359: Actor gradient remains a placeholder, not SAC's policy gradient.

Lines 346-352 still set every gradient component to `policyLoss * 0.01`, which is not ∇_θ[α log π_θ(a|s) − Q(s,a)]. This placeholder gradient will not train the policy correctly. The proper gradient requires backpropagating the scalar policy loss through the network, either via automatic differentiation of the loss w.r.t. policy outputs, or by manually deriving ∂loss/∂mean and ∂loss/∂log_std, then calling `Backward` with those gradients.
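The scalar the actor should actually be minimising, written out as a sketch; the key point of the comment is that this value must be backpropagated through the policy network rather than replaced by a constant per-component gradient:

```csharp
using System;

// SAC actor objective per sample: alpha * log pi_theta(a|s) - min(Q1(s,a), Q2(s,a)),
// with the action a drawn via the reparameterization trick so the loss stays differentiable in theta.
public static class SacActorLossSketch
{
    public static double ActorLoss(
        double alpha,
        double logProb,       // log pi_theta(a|s) for the reparameterized action a
        double q1, double q2) // twin critic estimates for (s, a)
    {
        double minQ = Math.Min(q1, q2);
        return alpha * logProb - minQ;   // minimize the entropy-regularized negative value
    }
}
```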
src/ReinforcementLearning/Agents/REINFORCE/REINFORCEAgent.cs (1)

207-233: Policy gradient still ignores the logits.

The backward vector is filled with `loss * 0.01`, identical for every logit dimension, so ∇θ log π(a|s) never factors into the update. This is the same issue coderabbitai flagged earlier and the implementation remains unchanged. Please backpropagate the actual log-prob gradients (e.g., `(p - a)` for categorical, Gaussian derivative for continuous) before scaling by the return; otherwise REINFORCE cannot learn.
src/ReinforcementLearning/Agents/A2C/A2CAgent.cs (1)

229-325: Still no backprop before applying gradients.

Line 259 still calls `UpdatePolicyNetwork()` without ever running `_policyNetwork.Backward(...)` (or zeroing grads) after the loss computation, so `GetFlattenedGradients()` is returning stale/zero data and the policy update is a no-op, the exact issue noted in the previous review. Please backprop the policy and value losses through their respective networks before fetching gradients. In addition, line 321 currently adds `learningRate * gradient`; once real gradients are in place you must subtract that term (as in the PPO/REINFORCE agents) or you will ascend the loss and the actor will diverge.
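The ordering the comment asks for, in schematic form. The method names mirror those mentioned in the review, but the interface below is a hypothetical stand-in, not the library's API:

```csharp
// 1) backprop the loss gradient, 2) read fresh gradients, 3) take a descent step (subtract).
public static class PolicyUpdateOrderSketch
{
    // Hypothetical minimal network surface for this sketch.
    public interface INetworkSketch
    {
        void Backward(double[] outputGradient);
        double[] GetFlattenedParameters();
        double[] GetFlattenedGradients();
        void UpdateParameters(double[] parameters);
    }

    public static void UpdateStep(INetworkSketch network, double[] lossGradient, double learningRate)
    {
        network.Backward(lossGradient);                        // 1) backprop first

        double[] parameters = network.GetFlattenedParameters();
        double[] gradients = network.GetFlattenedGradients();  // 2) now the gradients are fresh

        for (int i = 0; i < parameters.Length; i++)
            parameters[i] -= learningRate * gradients[i];      // 3) gradient descent: subtract

        network.UpdateParameters(parameters);
    }
}
```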
src/Models/Options/TD3Options.cs (1)

17-17: Use the existing MeanSquaredErrorLoss implementation.

`MeanSquaredError<T>` isn't defined anywhere in the solution; the loss implementation is `MeanSquaredErrorLoss<T>` (see NeuralNetworkHelper defaults). Leaving this initializer as-is breaks compilation, so please swap it to the actual loss type.

```diff
- public ILossFunction<T> CriticLossFunction { get; set; } = new MeanSquaredError<T>();
+ public ILossFunction<T> CriticLossFunction { get; set; } = new MeanSquaredErrorLoss<T>();
```

src/ReinforcementLearning/Agents/DoubleDQN/DoubleDQNAgent.cs (2)
59-75: Don't dereference `options` before null-checking it.

`options` is dereferenced in the base-call object initializer before we guard against null, so a null caller still crashes with `NullReferenceException`. Hoist the guard into a helper that validates and builds the base options first.

```diff
-    public DoubleDQNAgent(DoubleDQNOptions<T> options)
-        : base(new ReinforcementLearningOptions<T>
-        {
-            LearningRate = options.LearningRate,
-            DiscountFactor = options.DiscountFactor,
-            LossFunction = options.LossFunction,
-            Seed = options.Seed,
-            BatchSize = options.BatchSize,
-            ReplayBufferSize = options.ReplayBufferSize,
-            TargetUpdateFrequency = options.TargetUpdateFrequency,
-            WarmupSteps = options.WarmupSteps,
-            EpsilonStart = options.EpsilonStart,
-            EpsilonEnd = options.EpsilonEnd,
-            EpsilonDecay = options.EpsilonDecay
-        })
+    public DoubleDQNAgent(DoubleDQNOptions<T> options)
+        : base(CreateBaseOptions(options))
     {
-        _options = options ?? throw new ArgumentNullException(nameof(options));
+        _options = options;
@@
+    private static ReinforcementLearningOptions<T> CreateBaseOptions(DoubleDQNOptions<T> options)
+    {
+        ArgumentNullException.ThrowIfNull(options);
+
+        return new ReinforcementLearningOptions<T>
+        {
+            LearningRate = options.LearningRate,
+            DiscountFactor = options.DiscountFactor,
+            LossFunction = options.LossFunction,
+            Seed = options.Seed,
+            BatchSize = options.BatchSize,
+            ReplayBufferSize = options.ReplayBufferSize,
+            TargetUpdateFrequency = options.TargetUpdateFrequency,
+            WarmupSteps = options.WarmupSteps,
+            EpsilonStart = options.EpsilonStart,
+            EpsilonEnd = options.EpsilonEnd,
+            EpsilonDecay = options.EpsilonDecay
+        };
+    }
```
290-298: Keep target network in sync after external parameter injection.

`SetParameters` only updates the online network, leaving the target network stale and causing divergent behaviour right after a parameter load. Copy the fresh weights into `_targetNetwork` too.

```diff
     _qNetwork.UpdateParameters(vector);
+    CopyNetworkWeights(_qNetwork, _targetNetwork);
```

src/ReinforcementLearning/Agents/PPO/PPOAgent.cs (2)
327-334: Fix terminal-state bootstrap in GAE.

Line 329 reuses `_trajectory.Values[t]` when the last step is done, and line 333 always bootstraps from `_trajectory.Values[t + 1]` even if `_trajectory.Dones[t]` is true. Terminal transitions must use `nextValue = 0`, otherwise advantages leak value estimates across episode boundaries. Please adopt the fix previously suggested so both branches zero out `nextValue` when the transition is terminal.
494-511: Replace the placeholder policy gradient with PPO's clipped objective.

`UpdatePolicyNetwork` still sends `action * advantage` (lines 494-499) as the output gradient, ignoring the clipped probability ratio you computed in `UpdateNetworks`. This is the same placeholder flagged earlier; the policy update no longer matches PPO and the agent cannot learn correctly. Please backpropagate the gradient of the clipped surrogate (or equivalent log-prob formulation) instead of this simplified stand-in.

src/ReinforcementLearning/Agents/DuelingDQN/DuelingDQNAgent.cs (6)
316-319: Critical bug: ApplyGradients replaces parameters instead of applying gradient updates.

This issue was previously flagged and remains unfixed. The implementation incorrectly replaces all parameters with gradient values instead of performing gradient descent (`params -= learningRate * gradients`).

450-473: Critical: Backward method is incomplete and won't update network weights.

This issue was previously flagged and remains unfixed. The method computes gradient distributions for value and advantage streams but doesn't propagate them through the layer hierarchy (line 472 explicitly notes this). Without proper backpropagation, gradients never reach layer weights and the network cannot learn.

475-490: Critical: UpdateWeights is a stub - network parameters never update.

This issue was previously flagged and remains unfixed. The method contains empty loops with placeholder comments. Without weight updates, the network cannot learn from backpropagated gradients.

492-515: Critical: GetFlattenedParameters is incomplete - parameter management will fail.

This issue was previously flagged and remains unfixed. The method has empty loops with placeholder comments and returns a matrix based on an empty list, breaking serialization, cloning, and parameter inspection.

517-520: Critical: SetFlattenedParameters is a stub - cloning and parameter loading will fail.

This issue was previously flagged and remains unfixed. The empty stub prevents restoring network state, breaking `Clone()`, `Deserialize()`, and external parameter injection.

522-533: Critical: Serialize/Deserialize are stubs - model persistence is broken.

This issue was previously flagged and remains unfixed. Empty implementations prevent saving and loading trained models, making the agent unusable in production scenarios.
🧹 Nitpick comments (7)
src/Models/Options/MuZeroOptions.cs (3)
1-1: Remove unused using directive.

The `AiDotNet.LossFunctions` namespace is imported but not used anywhere in this file. Apply this diff:

```diff
-using AiDotNet.LossFunctions;
-
 namespace AiDotNet.Models.Options;
```
33-34: Consider validation or required initialization for ObservationSize and ActionSize.These properties are critical for MuZero agent configuration but have no defaults or validation. If not set explicitly, they will remain at their default value of 0, which could lead to runtime errors or unexpected behavior during agent initialization.
Consider one of the following approaches:
- Add validation in the constructor or via property setters
- Make them required constructor parameters
- Add XML documentation indicating they are required properties
Example validation approach:
public int ObservationSize { get; set; } public int ActionSize { get; set; } public MuZeroOptions() { var numOps = NumericOperations<T>.Instance; LearningRate = numOps.FromDouble(0.0001); DiscountFactor = numOps.FromDouble(0.997); } public void Validate() { if (ObservationSize <= 0) throw new ArgumentException("ObservationSize must be positive.", nameof(ObservationSize)); if (ActionSize <= 0) throw new ArgumentException("ActionSize must be positive.", nameof(ActionSize)); }
54-55: Consider documenting memory and prioritization implications.The
ReplayBufferSizeof 1 million andPriorityAlphaof 1.0 are reasonable defaults for MuZero, but might benefit from brief documentation:
ReplayBufferSize: At 1M transitions with potentially high-dimensional states (images, etc.), memory usage can be substantial. Users should be aware of this when configuring for resource-constrained environments.
PriorityAlpha: A value of 1.0 applies full prioritization (priorities are used as-is). The typical range is [0.0, 1.0], where 0.0 is uniform sampling and 1.0 is maximum prioritization. This aggressive default may not be optimal for all tasks.
Example documentation addition:
// Training parameters public int UnrollSteps { get; set; } = 5; // Number of steps to unroll for training public int TDSteps { get; set; } = 10; // TD bootstrap steps public int BatchSize { get; set; } = 256; public int ReplayBufferSize { get; set; } = 1000000; // 1M transitions; consider memory usage for high-dim states public double PriorityAlpha { get; set; } = 1.0; // Range [0.0, 1.0]: 0=uniform, 1=full prioritizationsrc/Models/Options/MADDPGOptions.cs (1)
51-58: Consider documenting or extracting default hyperparameter values.The hardcoded default values (0.0001, 0.001, 0.99, 0.001) are reasonable for MADDPG but could benefit from documentation explaining why these specific values were chosen, or extraction into named constants for clarity.
Example with constants:
+ // Default hyperparameters based on MADDPG paper recommendations + private const double DefaultActorLearningRate = 0.0001; + private const double DefaultCriticLearningRate = 0.001; + private const double DefaultDiscountFactor = 0.99; + private const double DefaultTargetUpdateTau = 0.001; + public MADDPGOptions() { var numOps = NumericOperations<T>.Instance; - ActorLearningRate = numOps.FromDouble(0.0001); - CriticLearningRate = numOps.FromDouble(0.001); - DiscountFactor = numOps.FromDouble(0.99); - TargetUpdateTau = numOps.FromDouble(0.001); + ActorLearningRate = numOps.FromDouble(DefaultActorLearningRate); + CriticLearningRate = numOps.FromDouble(DefaultCriticLearningRate); + DiscountFactor = numOps.FromDouble(DefaultDiscountFactor); + TargetUpdateTau = numOps.FromDouble(DefaultTargetUpdateTau); }src/Models/Options/WorldModelsOptions.cs (2)
1-1: Remove unused import.The
AiDotNet.LossFunctionsnamespace doesn't appear to be used in this file.-using AiDotNet.LossFunctions; - namespace AiDotNet.Models.Options;
33-62: Consider adding property validation.The configuration properties lack validation (e.g., positive values for dimensions, epochs; DiscountFactor in [0,1]; etc.). While validation may occur in
WorldModelsAgent, adding validation here would provide earlier error detection and clearer contracts.Example validation could be added via:
- Property setters with guards
- A
Validate()method called before use- Init-only properties with validation in constructor
src/Models/Options/DreamerOptions.cs (1)
33-55: Optional: Add XML documentation to properties.Individual properties lack XML documentation. Adding
<summary>tags would help users understand each parameter's purpose and valid ranges without needing to reference external documentation.Example for a few properties:
+ /// <summary> + /// The size of the observation space. + /// </summary> public int ObservationSize { get; set; } + /// <summary> + /// The size of the action space. + /// </summary> public int ActionSize { get; set; } + /// <summary> + /// Learning rate for training the world model and policy. Default: 0.0001 + /// </summary> public T LearningRate { get; set; }
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (34)
- src/Models/Options/A2COptions.cs (1 hunks)
- src/Models/Options/A3COptions.cs (1 hunks)
- src/Models/Options/CQLOptions.cs (1 hunks)
- src/Models/Options/DDPGOptions.cs (1 hunks)
- src/Models/Options/DQNOptions.cs (1 hunks)
- src/Models/Options/DecisionTransformerOptions.cs (1 hunks)
- src/Models/Options/DoubleDQNOptions.cs (1 hunks)
- src/Models/Options/DreamerOptions.cs (1 hunks)
- src/Models/Options/DuelingDQNOptions.cs (1 hunks)
- src/Models/Options/IQLOptions.cs (1 hunks)
- src/Models/Options/MADDPGOptions.cs (1 hunks)
- src/Models/Options/MuZeroOptions.cs (1 hunks)
- src/Models/Options/PPOOptions.cs (1 hunks)
- src/Models/Options/QMIXOptions.cs (1 hunks)
- src/Models/Options/REINFORCEOptions.cs (1 hunks)
- src/Models/Options/RainbowDQNOptions.cs (1 hunks)
- src/Models/Options/SACOptions.cs (1 hunks)
- src/Models/Options/TD3Options.cs (1 hunks)
- src/Models/Options/TRPOOptions.cs (1 hunks)
- src/Models/Options/WorldModelsOptions.cs (1 hunks)
- src/ReinforcementLearning/Agents/A2C/A2CAgent.cs (1 hunks)
- src/ReinforcementLearning/Agents/A3C/A3CAgent.cs (1 hunks)
- src/ReinforcementLearning/Agents/CQL/CQLAgent.cs (1 hunks)
- src/ReinforcementLearning/Agents/DDPG/DDPGAgent.cs (1 hunks)
- src/ReinforcementLearning/Agents/DQN/DQNAgent.cs (1 hunks)
- src/ReinforcementLearning/Agents/DoubleDQN/DoubleDQNAgent.cs (1 hunks)
- src/ReinforcementLearning/Agents/DuelingDQN/DuelingDQNAgent.cs (1 hunks)
- src/ReinforcementLearning/Agents/IQL/IQLAgent.cs (1 hunks)
- src/ReinforcementLearning/Agents/PPO/PPOAgent.cs (1 hunks)
- src/ReinforcementLearning/Agents/REINFORCE/REINFORCEAgent.cs (1 hunks)
- src/ReinforcementLearning/Agents/SAC/SACAgent.cs (1 hunks)
- src/ReinforcementLearning/Agents/TD3/TD3Agent.cs (1 hunks)
- src/ReinforcementLearning/Agents/TRPO/TRPOAgent.cs (1 hunks)
- src/ReinforcementLearning/ReplayBuffers/PrioritizedReplayBuffer.cs (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (31)
src/ReinforcementLearning/Agents/DDPG/DDPGAgent.cs (4)
src/Models/Options/DDPGOptions.cs (2)
DDPGOptions(9-34)DDPGOptions(26-33)src/ReinforcementLearning/Agents/SAC/SACAgent.cs (11)
NeuralNetwork(98-122)NeuralNetwork(124-149)CopyNetworkWeights(573-576)SoftUpdateTargets(388-392)UpdateNetworkParameters(413-425)SoftUpdateNetwork(394-411)Serialize(459-482)Deserialize(485-506)Matrix(509-524)Matrix(560-564)SetParameters(527-549)src/ReinforcementLearning/Agents/ReinforcementLearningAgentBase.cs (1)
ReinforcementLearningOptions(316-382)src/Helpers/MathHelper.cs (1)
MathHelper(19-992)
src/Models/Options/DecisionTransformerOptions.cs (1)
src/Helpers/NeuralNetworkHelper.cs (1)
ILossFunction(49-76)
src/ReinforcementLearning/Agents/DoubleDQN/DoubleDQNAgent.cs (5)
src/ReinforcementLearning/Agents/DDPG/DDPGAgent.cs (16)
T(173-198)T(200-240)T(242-272)NeuralNetwork(92-115)NeuralNetwork(117-142)CopyNetworkWeights(447-450)Vector(145-164)Vector(313-321)Vector(475-492)Serialize(345-366)Deserialize(369-388)Matrix(391-404)Matrix(435-439)SetParameters(407-424)IFullModel(427-432)ApplyGradients(442-445)src/ReinforcementLearning/Agents/DQN/DQNAgent.cs (13)
T(151-229)T(411-424)NeuralNetwork(92-119)CopyNetworkWeights(388-392)Vector(122-141)ArgMax(394-409)Serialize(254-278)Deserialize(281-303)Matrix(306-315)Matrix(355-369)SetParameters(318-326)IFullModel(329-352)ApplyGradients(372-384)src/ReinforcementLearning/Agents/DuelingDQN/DuelingDQNAgent.cs (16)
T(125-188)CopyNetworkWeights(322-326)Vector(100-116)Vector(404-448)ArgMax(328-343)Backward(450-473)Serialize(213-234)Serialize(522-528)Deserialize(237-256)Deserialize(530-533)Matrix(259-262)Matrix(299-313)Matrix(492-515)SetParameters(265-268)IFullModel(271-296)ApplyGradients(316-319)src/Models/Options/DoubleDQNOptions.cs (2)
DoubleDQNOptions(9-32)DoubleDQNOptions(26-31)src/ReinforcementLearning/Agents/REINFORCE/REINFORCEAgent.cs (11)
NeuralNetwork(70-94)Vector(97-109)Vector(111-119)ArgMax(433-442)Serialize(316-329)Deserialize(332-343)Matrix(346-357)Matrix(379-383)SetParameters(360-368)IFullModel(371-376)ApplyGradients(386-389)
src/ReinforcementLearning/Agents/A3C/A3CAgent.cs (4)
src/Helpers/MathHelper.cs (2)
INumericOperations(36-66)MathHelper(19-992)src/ReinforcementLearning/Agents/REINFORCE/REINFORCEAgent.cs (1)
ComputeReturns(240-279)src/Helpers/StatisticsHelper.cs (1)
StatisticsHelper(17-6685)src/NeuralNetworks/Layers/DenseLayer.cs (1)
SetWeights(564-579)
src/Models/Options/DuelingDQNOptions.cs (3)
src/ReinforcementLearning/Agents/DQN/DQNAgent.cs (2)
T(151-229)T(411-424)src/ReinforcementLearning/Agents/DuelingDQN/DuelingDQNAgent.cs (1)
T(125-188)src/Helpers/NeuralNetworkHelper.cs (1)
ILossFunction(49-76)
src/Models/Options/DDPGOptions.cs (2)
src/ReinforcementLearning/Agents/DDPG/DDPGAgent.cs (3)
T(173-198)T(200-240)T(242-272)src/Helpers/NeuralNetworkHelper.cs (1)
ILossFunction(49-76)
src/Models/Options/MADDPGOptions.cs (1)
src/ReinforcementLearning/ReplayBuffers/PrioritizedReplayBuffer.cs (1)
List(52-108)
src/Models/Options/IQLOptions.cs (2)
src/ReinforcementLearning/Agents/CQL/CQLAgent.cs (4)
T(180-211)T(213-274)T(276-307)T(309-344)src/Helpers/NeuralNetworkHelper.cs (1)
ILossFunction(49-76)
src/Models/Options/DoubleDQNOptions.cs (3)
src/ReinforcementLearning/Agents/DQN/DQNAgent.cs (2)
T(151-229)T(411-424)src/ReinforcementLearning/Agents/DoubleDQN/DoubleDQNAgent.cs (1)
T(137-207)src/Helpers/NeuralNetworkHelper.cs (1)
ILossFunction(49-76)
src/Models/Options/DQNOptions.cs (2)
src/ReinforcementLearning/Agents/DQN/DQNAgent.cs (2)
T(151-229)T(411-424)src/Helpers/NeuralNetworkHelper.cs (1)
ILossFunction(49-76)
src/Models/Options/MuZeroOptions.cs (1)
src/ReinforcementLearning/ReplayBuffers/PrioritizedReplayBuffer.cs (1)
List(52-108)
src/Models/Options/RainbowDQNOptions.cs (2)
src/ReinforcementLearning/Agents/TD3/TD3Agent.cs (2)
T(151-175)T(177-241)src/Helpers/NeuralNetworkHelper.cs (1)
ILossFunction(49-76)
src/ReinforcementLearning/Agents/TRPO/TRPOAgent.cs (3)
src/Helpers/MathHelper.cs (2)
INumericOperations(36-66)MathHelper(19-992)src/Helpers/StatisticsHelper.cs (1)
StatisticsHelper(17-6685)src/NeuralNetworks/Layers/DenseLayer.cs (1)
SetWeights(564-579)
src/ReinforcementLearning/Agents/DQN/DQNAgent.cs (4)
src/ReinforcementLearning/Agents/DDPG/DDPGAgent.cs (19)
T(173-198)T(200-240)T(242-272)NeuralNetwork(92-115)NeuralNetwork(117-142)CopyNetworkWeights(447-450)Vector(145-164)Vector(313-321)Vector(475-492)StoreExperience(167-170)Dictionary(324-330)ModelMetadata(333-342)Serialize(345-366)Deserialize(369-388)Matrix(391-404)Matrix(435-439)SetParameters(407-424)IFullModel(427-432)ApplyGradients(442-445)src/ReinforcementLearning/Agents/DoubleDQN/DoubleDQNAgent.cs (15)
T(137-207)NeuralNetwork(87-109)CopyNetworkWeights(359-362)Vector(112-128)ArgMax(364-379)StoreExperience(131-134)Dictionary(210-217)ModelMetadata(220-229)Serialize(232-253)Deserialize(256-275)Matrix(278-287)Matrix(327-341)SetParameters(290-298)IFullModel(301-324)ApplyGradients(344-356)src/Models/Options/DQNOptions.cs (2)
DQNOptions(9-32)DQNOptions(26-31)src/ReinforcementLearning/ReplayBuffers/UniformReplayBuffer.cs (5)
UniformReplayBuffer(18-96)UniformReplayBuffer(35-44)Add(47-59)List(62-82)CanSample(85-88)
src/ReinforcementLearning/ReplayBuffers/PrioritizedReplayBuffer.cs (2)
src/ReinforcementLearning/ReplayBuffers/UniformReplayBuffer.cs (2)
List(62-82)Add(47-59)src/Helpers/MathHelper.cs (1)
INumericOperations(36-66)
src/Models/Options/CQLOptions.cs (1)
src/ReinforcementLearning/Agents/CQL/CQLAgent.cs (4)
T(180-211)T(213-274)T(276-307)T(309-344)
src/Models/Options/TRPOOptions.cs (1)
src/Helpers/NeuralNetworkHelper.cs (1)
ILossFunction(49-76)
src/ReinforcementLearning/Agents/IQL/IQLAgent.cs (2)
src/Helpers/MathHelper.cs (2)
INumericOperations(36-66)MathHelper(19-992)src/NeuralNetworks/Layers/DenseLayer.cs (1)
SetWeights(564-579)
src/ReinforcementLearning/Agents/SAC/SACAgent.cs (5)
src/ReinforcementLearning/Agents/A2C/A2CAgent.cs (7)
T(170-201)T(204-266)T(341-372)Vector(113-125)Vector(127-135)Vector(137-160)Vector(487-512)src/ReinforcementLearning/Agents/DDPG/DDPGAgent.cs (10)
T(173-198)T(200-240)T(242-272)CopyNetworkWeights(447-450)Vector(145-164)Vector(313-321)Vector(475-492)SoftUpdateTargets(274-278)UpdateNetworkParameters(299-311)SoftUpdateNetwork(280-297)src/Models/Options/SACOptions.cs (2)
SACOptions(30-166)SACOptions(156-165)src/ReinforcementLearning/Agents/ReinforcementLearningAgentBase.cs (1)
ReinforcementLearningOptions(316-382)src/Helpers/MathHelper.cs (1)
MathHelper(19-992)
src/ReinforcementLearning/Agents/REINFORCE/REINFORCEAgent.cs (4)
src/ReinforcementLearning/Agents/A2C/A2CAgent.cs (13)
T(170-201)T(204-266)T(341-372)NeuralNetwork(67-88)NeuralNetwork(90-110)Vector(113-125)Vector(127-135)Vector(137-160)Vector(487-512)SampleCategorical(514-526)ArgMax(528-537)StoreExperience(163-168)UpdatePolicyNetwork(313-325)src/Models/Options/REINFORCEOptions.cs (2)
REINFORCEOptions(25-41)REINFORCEOptions(35-40)src/ReinforcementLearning/Agents/ReinforcementLearningAgentBase.cs (1)
ReinforcementLearningOptions(316-382)src/ReinforcementLearning/Common/Trajectory.cs (1)
AddStep(85-93)
src/Models/Options/SACOptions.cs (5)
src/ReinforcementLearning/Agents/A2C/A2CAgent.cs (3)
T(170-201)T(204-266)T(341-372)src/ReinforcementLearning/Agents/DDPG/DDPGAgent.cs (3)
T(173-198)T(200-240)T(242-272)src/ReinforcementLearning/Agents/DQN/DQNAgent.cs (2)
T(151-229)T(411-424)src/ReinforcementLearning/Agents/PPO/PPOAgent.cs (6)
T(223-235)T(237-245)T(247-271)T(274-314)T(377-436)T(438-476)src/Helpers/NeuralNetworkHelper.cs (1)
ILossFunction(49-76)
src/ReinforcementLearning/Agents/DuelingDQN/DuelingDQNAgent.cs (5)
src/ReinforcementLearning/Agents/DDPG/DDPGAgent.cs (11)
T(173-198)T(200-240)T(242-272)CopyNetworkWeights(447-450)Vector(145-164)Vector(313-321)Vector(475-492)Serialize(345-366)Deserialize(369-388)Matrix(391-404)Matrix(435-439)src/ReinforcementLearning/Agents/DQN/DQNAgent.cs (9)
T(151-229)T(411-424)CopyNetworkWeights(388-392)Vector(122-141)ArgMax(394-409)Serialize(254-278)Deserialize(281-303)Matrix(306-315)Matrix(355-369)src/ReinforcementLearning/Agents/DoubleDQN/DoubleDQNAgent.cs (8)
T(137-207)CopyNetworkWeights(359-362)Vector(112-128)ArgMax(364-379)Serialize(232-253)Deserialize(256-275)Matrix(278-287)Matrix(327-341)src/Models/Options/DuelingDQNOptions.cs (2)
DuelingDQNOptions(9-34)DuelingDQNOptions(28-33)src/ReinforcementLearning/Agents/ReinforcementLearningAgentBase.cs (1)
ReinforcementLearningOptions(316-382)
src/ReinforcementLearning/Agents/PPO/PPOAgent.cs (4)
src/ReinforcementLearning/Agents/A2C/A2CAgent.cs (15)
T(170-201)T(204-266)T(341-372)NeuralNetwork(67-88)NeuralNetwork(90-110)Vector(113-125)Vector(127-135)Vector(137-160)Vector(487-512)SampleCategorical(514-526)ArgMax(528-537)StoreExperience(163-168)ComputeAdvantages(268-311)UpdatePolicyNetwork(313-325)UpdateValueNetwork(327-339)src/Models/Options/PPOOptions.cs (2)
PPOOptions(29-170)PPOOptions(159-169)src/ReinforcementLearning/Common/Trajectory.cs (4)
Trajectory(27-114)Trajectory(72-80)AddStep(85-93)Clear(103-113)src/Helpers/MathHelper.cs (1)
MathHelper(19-992)
src/ReinforcementLearning/Agents/CQL/CQLAgent.cs (3)
src/Models/Options/CQLOptions.cs (2)
CQLOptions(29-69)CQLOptions(57-68)src/Helpers/MathHelper.cs (2)
INumericOperations(36-66)MathHelper(19-992)src/NeuralNetworks/Layers/DenseLayer.cs (1)
SetWeights(564-579)
src/Models/Options/REINFORCEOptions.cs (2)
src/ReinforcementLearning/Agents/A2C/A2CAgent.cs (3)
T(170-201)T(204-266)T(341-372)src/ReinforcementLearning/Agents/PPO/PPOAgent.cs (6)
T(223-235)T(237-245)T(247-271)T(274-314)T(377-436)T(438-476)
src/Models/Options/TD3Options.cs (2)
src/ReinforcementLearning/Agents/DDPG/DDPGAgent.cs (3)
T(173-198)T(200-240)T(242-272)src/Helpers/NeuralNetworkHelper.cs (1)
ILossFunction(49-76)
src/Models/Options/A3COptions.cs (1)
src/Helpers/NeuralNetworkHelper.cs (1)
ILossFunction(49-76)
src/Models/Options/A2COptions.cs (2)
src/ReinforcementLearning/Agents/A2C/A2CAgent.cs (3)
T(170-201)T(204-266)T(341-372)src/Helpers/NeuralNetworkHelper.cs (1)
ILossFunction(49-76)
src/Models/Options/PPOOptions.cs (2)
src/ReinforcementLearning/Agents/PPO/PPOAgent.cs (6)
T(223-235)T(237-245)T(247-271)T(274-314)T(377-436)T(438-476)src/Helpers/NeuralNetworkHelper.cs (1)
ILossFunction(49-76)
src/ReinforcementLearning/Agents/A2C/A2CAgent.cs (9)
src/ReinforcementLearning/Agents/DDPG/DDPGAgent.cs (12)
T(173-198)T(200-240)T(242-272)Vector(145-164)Vector(313-321)Vector(475-492)StoreExperience(167-170)ModelMetadata(333-342)Serialize(345-366)Deserialize(369-388)SetParameters(407-424)ApplyGradients(442-445)src/ReinforcementLearning/Agents/DQN/DQNAgent.cs (9)
T(151-229)Vector(122-141)ArgMax(394-409)StoreExperience(144-148)ModelMetadata(242-251)Serialize(254-278)Deserialize(281-303)SetParameters(318-326)ApplyGradients(372-384)src/ReinforcementLearning/Agents/PPO/PPOAgent.cs (22)
T(223-235)T(237-245)T(247-271)T(274-314)T(377-436)T(438-476)T(713-722)Vector(144-158)Vector(160-181)Vector(183-208)Vector(666-685)SampleCategorical(687-700)ArgMax(702-711)StoreExperience(211-221)ComputeAdvantages(316-375)UpdatePolicyNetwork(478-512)UpdateValueNetwork(514-542)ModelMetadata(553-562)Serialize(565-583)Deserialize(586-602)SetParameters(623-639)ApplyGradients(660-663)src/ReinforcementLearning/Agents/REINFORCE/REINFORCEAgent.cs (15)
T(154-188)T(191-238)Vector(97-109)Vector(111-119)Vector(121-144)Vector(392-417)SampleCategorical(419-431)ArgMax(433-442)StoreExperience(147-152)UpdatePolicyNetwork(281-293)ModelMetadata(304-313)Serialize(316-329)Deserialize(332-343)SetParameters(360-368)ApplyGradients(386-389)src/Models/Options/A2COptions.cs (2)
A2COptions(25-50)A2COptions(41-49)src/ReinforcementLearning/Agents/ReinforcementLearningAgentBase.cs (1)
ReinforcementLearningOptions(316-382)src/Helpers/MathHelper.cs (1)
MathHelper(19-992)src/ReinforcementLearning/Common/Trajectory.cs (1)
AddStep(85-93)src/Helpers/StatisticsHelper.cs (1)
StatisticsHelper(17-6685)
src/ReinforcementLearning/Agents/TD3/TD3Agent.cs (2)
src/Helpers/MathHelper.cs (2)
INumericOperations(36-66)MathHelper(19-992)src/NeuralNetworks/Layers/DenseLayer.cs (1)
SetWeights(564-579)
🔇 Additional comments (33)
src/Models/Options/MuZeroOptions.cs (2)
61-66: LGTM! Correct use of NumericOperations pattern.The constructor properly initializes generic numeric properties using the
NumericOperations<T>pattern, which is consistent with other option classes in the project. The default learning rate (0.0001) and discount factor (0.997) are appropriate for MuZero, which is designed for complex, long-horizon tasks.
5-30: Excellent documentation and clear structure.The XML documentation provides comprehensive coverage of MuZero concepts with both technical details and beginner-friendly explanations. The use of concrete examples (Atari, Go, Chess) and analogies effectively communicates the algorithm's capabilities and innovations.
The class structure is well-organized with logical grouping of related parameters (network architecture, MCTS, training), making it easy to understand and configure.
src/Models/Options/DecisionTransformerOptions.cs (4)
1-3: LGTM: Imports and namespace are correct.The namespace follows the project conventions, and the import for LossFunctions aligns with the properties used in the class.
5-30: LGTM: Excellent documentation.The XML documentation provides comprehensive coverage of the Decision Transformer concept, including both technical details and beginner-friendly explanations. This will help users understand when and how to use this approach.
40-40: Verify single attention head default.

The default of NumHeads = 1 is atypical for transformer architectures, which usually benefit from multi-head attention (commonly 4, 8, or more heads). While this may be intentional for simplicity or computational efficiency in RL contexts, please verify this choice aligns with the intended Decision Transformer implementation.

If multi-head attention is desired, consider:

-public int NumHeads { get; set; } = 1;
+public int NumHeads { get; set; } = 4;
50-54: LGTM: Learning rate initialization is appropriate.The constructor correctly initializes
LearningRateusingNumericOperations<T>to handle generic numeric types. The default value of 0.0001 is appropriate for transformer-based architectures.src/Models/Options/WorldModelsOptions.cs (3)
5-31: Excellent documentation!The XML documentation is thorough and the beginner-friendly explanation of World Models architecture is particularly helpful. The breakdown of VAE, MDN-RNN, and Controller components clearly communicates the model's design.
64-69: LGTM - consistent initialization pattern.The constructor correctly initializes generic numeric properties using
NumericOperations<T>, consistent with other option classes in the codebase. The default values (0.001 for learning rate, 0.99 for discount factor) are reasonable for RL agents.
36-36: Dismiss the review comment – ActionSize default is intentional and follows codebase patterns.The
ActionSizeproperty has no explicit default because it is a required, caller-provided configuration parameter, not a generic default. This design matches every RL agent option class in the codebase: DQNOptions, PPOOptions, SACOptions, TD3Options, IQLOptions, CQLOptions, A3COptions, and others all use the same pattern—ActionSize has no default value and is environment-specific. LikeStateSize,ActionSizeis a mandatory parameter that must always be explicitly set by the caller before use; there is no sensible universal default for action space dimensionality. The codebase reflects this by never setting it to a default, and all usage sites access it directly without null-checks or validation, confirming that this is the intended design.Likely an incorrect or invalid review comment.
src/Models/Options/SACOptions.cs (1)
1-166: LGTM! Well-structured SAC configuration.The options class provides comprehensive SAC hyperparameters with sensible defaults (learning rates 0.0003, discount 0.99, tau 0.005, temperature 0.2) and clear documentation. The constructor properly initializes numeric types via
NumericOperations<T>, ensuring type safety across generic numeric types.src/ReinforcementLearning/Agents/SAC/SACAgent.cs (15)
62-96: LGTM! Constructor properly initializes SAC components.The constructor correctly maps SAC-specific options to the base class, builds all five networks (actor, two critics, two target critics), initializes temperature in log-space, and synchronizes target networks. Network registration ensures proper lifecycle management.
98-149: LGTM! Network architectures follow SAC specifications.Policy network correctly outputs
ActionSize * 2for Gaussian mean and log_std parameters. Q-networks properly accept concatenated state-action inputs and output scalar Q-values. Configurable hidden layers provide flexibility.
152-205: LGTM! Action selection correctly implements squashed Gaussian policy.The implementation properly handles:
- Stochastic sampling during training with reparameterization trick
- Deterministic mean for inference
- Tanh squashing to bound actions in [-1, 1]
- Change-of-variables correction in log-probability:
log p(a) = log p(z) - log(1 - tanh²(z))- Numerical stability via log_std clamping to [-20, 2]
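For reference, a minimal standalone sketch of the squashed-Gaussian sampling and log-probability correction listed above (plain C# with double rather than the library's generic T/NumOps types; names are illustrative, not the agent's actual API):

using System;

static class SquashedGaussianSketch
{
    // Sample a tanh-squashed Gaussian action and its corrected log-probability.
    // mean/logStd are the two policy-head outputs; logStd is clamped to [-20, 2] for stability.
    public static (double Action, double LogProb) Sample(double mean, double logStd, Random rng)
    {
        logStd = Math.Clamp(logStd, -20.0, 2.0);
        double std = Math.Exp(logStd);

        // Reparameterization trick: z = mean + std * eps, eps ~ N(0, 1)
        double eps = Math.Sqrt(-2.0 * Math.Log(1.0 - rng.NextDouble())) * Math.Cos(2.0 * Math.PI * rng.NextDouble());
        double z = mean + std * eps;

        // Gaussian log-density of z
        double logProbZ = -0.5 * (Math.Pow((z - mean) / std, 2) + 2.0 * logStd + Math.Log(2.0 * Math.PI));

        // Squash to [-1, 1] and apply the change-of-variables correction:
        // log p(a) = log p(z) - log(1 - tanh(z)^2)
        double action = Math.Tanh(z);
        double logProb = logProbZ - Math.Log(1.0 - action * action + 1e-6);

        return (action, logProb);
    }
}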
208-211: LGTM!Experience storage correctly delegates to the replay buffer.
214-253: LGTM! Training loop follows SAC algorithm structure.The method correctly:
- Gates training on warmup steps and replay buffer readiness
- Performs configurable gradient steps per environment step
- Updates critics → actor → temperature in proper sequence
- Soft-updates target networks
- Tracks average loss across gradient steps
255-319: LGTM! Critic updates implement SAC's twin Q-learning with entropy.The method correctly:
- Samples next actions from the current policy
- Uses
min(Q1_target, Q2_target)to reduce overestimation bias- Applies entropy regularization:
target = minQ - alpha * log_prob- Computes Bellman targets with proper terminal state handling
- Updates both Q-networks via loss computation, backpropagation, and parameter updates
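As a concrete illustration of that target (a sketch in plain C# with scalars; the agent itself batches this over Vector<T> samples):

// Bellman target per transition, with entropy regularization:
// y = r + gamma * (1 - done) * (min(Q1_target(s', a'), Q2_target(s', a')) - alpha * log pi(a'|s'))
static double SacCriticTarget(
    double reward, bool done, double gamma, double alpha,
    double q1Target, double q2Target, double nextActionLogProb)
{
    double minQ = Math.Min(q1Target, q2Target);
    double entropyAdjusted = minQ - alpha * nextActionLogProb;
    return reward + (done ? 0.0 : gamma * entropyAdjusted);
}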
361-386: LGTM! Temperature auto-tuning implements SAC's entropy constraint.The method correctly:
- Computes average policy entropy across the batch
- Uses target entropy (defaulting to
-ActionSizefor continuous actions)- Computes alpha loss:
L = -α * (log_π + H_target)- Updates log_alpha via gradient descent:
log_α ← log_α - lr * ∂L/∂log_αThis ensures the policy maintains the desired entropy level automatically.
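A compact sketch of that update (plain C#; the batch of log-probabilities stands in for the sampled policy actions, and the helper name is illustrative):

// One gradient-descent step on log(alpha) toward the target entropy H_target (default -ActionSize).
// Loss: L = -alpha * mean(log_pi + H_target), so dL/d(log alpha) = -alpha * mean(log_pi + H_target).
static double UpdateLogAlpha(double logAlpha, double[] batchLogProbs, double targetEntropy, double learningRate)
{
    double alpha = Math.Exp(logAlpha);

    double mean = 0.0;
    foreach (double logProb in batchLogProbs)
    {
        mean += logProb + targetEntropy;
    }
    mean /= batchLogProbs.Length;

    double grad = -alpha * mean;            // derivative of the alpha loss w.r.t. log(alpha)
    return logAlpha - learningRate * grad;  // alpha rises when entropy drops below the target
}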
388-411: LGTM! Soft target updates use correct Polyak averaging.The implementation properly applies:
θ_target ← τ * θ_source + (1 - τ) * θ_targetfor gradual target network updates, which stabilizes training.
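The update in scalar form, for readers who want the arithmetic spelled out (an illustrative sketch over flat parameter arrays, not the agent's network API):

// Soft (Polyak) target update: theta_target <- tau * theta_source + (1 - tau) * theta_target
static void SoftUpdate(double[] targetParams, double[] sourceParams, double tau)
{
    for (int i = 0; i < targetParams.Length; i++)
    {
        targetParams[i] = tau * sourceParams[i] + (1.0 - tau) * targetParams[i];
    }
}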
413-425: LGTM! Parameter updates apply standard SGD.The method correctly retrieves flattened parameters and gradients, applies the update rule
θ ← θ - lr * ∇θ, and updates the network.
427-435: LGTM!The helper correctly concatenates state and action vectors for Q-network input.
438-456: LGTM! Metrics and metadata expose SAC-specific information.
GetMetricsappropriately adds current temperature (Alpha) and replay buffer size.GetModelMetadataprovides complete model information including type, feature count, and training samples.
459-506: LGTM! Serialization captures complete agent state.The methods correctly serialize/deserialize state size, action size, log_alpha, and all five networks (policy, Q1, Q2, Q1_target, Q2_target) in a symmetrical manner.
509-549: LGTM! Parameter management correctly aggregates networks.
GetParametersconcatenates all trainable network parameters.SetParametersproperly distributes parameters back and synchronizes target networks to maintain consistency.
552-570: LGTM! Clone works correctly; gradient methods are interface stubs.
Cloneproperly creates a new agent and copies parameters.ComputeGradientsandApplyGradientsare no-op stubs, which is appropriate since SAC uses its own actor-critic update logic rather than the standard supervised learning gradient interface.
573-576: LGTM!The helper correctly performs hard copies of network weights for target network initialization.
src/Models/Options/DreamerOptions.cs (2)
5-30: Excellent documentation quality.The XML documentation is comprehensive and well-structured, providing both technical details and beginner-friendly explanations of the Dreamer algorithm. This is valuable for library users at different experience levels.
57-62: Appropriate initialization pattern for generic numeric types.The use of
NumericOperations<T>.Instance.FromDouble()correctly handles initialization of generic numeric properties. The default values (0.0001 for learning rate, 0.99 for discount factor) are sensible defaults for RL algorithms.Note: The constructor's lack of
ObservationSizeandActionSizeinitialization is already addressed in a separate comment.src/ReinforcementLearning/Agents/DuelingDQN/DuelingDQNAgent.cs (6)
125-188: Training loop depends on incomplete network implementation.The Train method follows the correct Double DQN pattern and is well-structured. However, lines 171 and 174 call
_qNetwork.Backward()and_qNetwork.UpdateWeights(), which are incomplete stubs in the DuelingNetwork class (see separate comments below). Until those methods are implemented, this training loop cannot actually update network weights.
213-234: Serialization depends on incomplete network Serialize method.The agent-level serialization logic is well-structured, but lines 225 and 229 call
_qNetwork.Serialize()and_targetNetwork.Serialize(), which are stubs (see comment on lines 522-533). Until network serialization is implemented, this method will produce incomplete data.
237-256: Deserialization depends on incomplete network Deserialize method.Lines 251 and 255 call
_qNetwork.Deserialize()and_targetNetwork.Deserialize(), which are stubs (see comment on lines 522-533). Until network deserialization is implemented, this method cannot restore saved models.
259-268: Parameter access depends on incomplete network methods.These methods delegate to
GetFlattenedParameters()andSetFlattenedParameters()in DuelingNetwork, which are stubs (see separate comments below). This breaks parameter inspection, cloning, and external parameter injection.
271-296: Clone depends on incomplete SetParameters method.Line 294 calls
SetParameters(GetParameters())to copy network weights, but both methods depend on incomplete DuelingNetwork implementations (see comment on lines 259-268). Until parameter serialization is fixed, cloning will fail.Additionally, line 280 sets
EpsilonStart = _epsilon, meaning the clone inherits the current (decayed) epsilon rather than the original starting value. This may be intentional for creating snapshots, but verify this matches your intended clone semantics.
404-448: LGTM: Forward pass correctly implements dueling architecture.The forward pass properly implements the dueling DQN formula: Q(s,a) = V(s) + (A(s,a) - mean(A(s,:))). The mean advantage centering (lines 432-437) and Q-value combination (lines 440-445) are correct.
Actionable comments posted: 15
🧹 Nitpick comments (4)
src/ReinforcementLearning/Agents/Dreamer/DreamerAgent.cs (4)
139-144: Consider making exploration noise configurable.The exploration noise standard deviation is hardcoded at 0.1. Different tasks may benefit from different exploration strategies. Consider adding an
ExplorationNoiseStdDevparameter toDreamerOptions<T>to allow tuning without code changes.Apply this diff to make exploration noise configurable:
-var noise = MathHelper.GetNormalRandom<T>(_numOps.Zero, _numOps.FromDouble(0.1));
+var noise = MathHelper.GetNormalRandom<T>(_numOps.Zero, _options.ExplorationNoiseStdDev);

Then add the corresponding property to DreamerOptions<T>:

// In DreamerOptions<T>
public T ExplorationNoiseStdDev { get; set; } // with default value in constructor
304-307: Consider making continue probability threshold configurable.The continue probability threshold is hardcoded at 0.5 (line 304). Different environments or tasks may benefit from different thresholds for episode termination detection.
-if (_numOps.Compare(continueProb, _numOps.FromDouble(0.5)) < 0)
+if (_numOps.Compare(continueProb, _options.ContinueProbabilityThreshold) < 0)

Add ContinueProbabilityThreshold to DreamerOptions<T> with a default of 0.5.
327-334: Metrics could include additional training indicators.Current metrics (update count, buffer size) are minimal. Consider adding:
- Average world model loss
- Average policy loss
- Exploration rate or noise level
- Average imagined return
These would provide better visibility into training progress and help with debugging.
313-325: Replace ConcatenateVectors with the built-in Vector<T>.Concatenate method.

The Vector<T> class provides a static Concatenate method that accepts multiple vectors via a params array. Replace the private helper with direct calls:

var dynamicsInput = Vector<T>.Concatenate(latentState, experience.action);
var dynamicsInput = Vector<T>.Concatenate(latentState, action);

Then remove the ConcatenateVectors method entirely. This improves consistency with the Vector API and eliminates redundant code.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (7)
src/ReinforcementLearning/Agents/DecisionTransformer/DecisionTransformerAgent.cs(1 hunks)src/ReinforcementLearning/Agents/Dreamer/DreamerAgent.cs(1 hunks)src/ReinforcementLearning/Agents/MADDPG/MADDPGAgent.cs(1 hunks)src/ReinforcementLearning/Agents/MuZero/MuZeroAgent.cs(1 hunks)src/ReinforcementLearning/Agents/QMIX/QMIXAgent.cs(1 hunks)src/ReinforcementLearning/Agents/Rainbow/RainbowDQNAgent.cs(1 hunks)src/ReinforcementLearning/Agents/WorldModels/WorldModelsAgent.cs(1 hunks)
🧰 Additional context used
🧬 Code graph analysis (7)
src/ReinforcementLearning/Agents/MuZero/MuZeroAgent.cs (3)
src/ReinforcementLearning/Agents/Dreamer/DreamerAgent.cs (8)
T(155-173)T(175-234)T(236-281)T(283-311)Vector(128-148)Vector(313-325)Vector(341-344)InitializeReplayBuffer(123-126)src/ReinforcementLearning/Agents/MADDPG/MADDPGAgent.cs (9)
T(199-223)T(225-260)T(262-304)Vector(137-158)Vector(160-164)Vector(360-381)Vector(383-395)Vector(411-414)InitializeReplayBuffer(129-132)src/ReinforcementLearning/Agents/QMIX/QMIXAgent.cs (8)
T(197-281)T(412-423)Vector(135-158)Vector(160-164)Vector(320-332)Vector(334-355)InitializeReplayBuffer(127-130)ArgMax(395-410)
src/ReinforcementLearning/Agents/WorldModels/WorldModelsAgent.cs (1)
src/Helpers/MathHelper.cs (2)
INumericOperations(36-66)MathHelper(19-992)
src/ReinforcementLearning/Agents/QMIX/QMIXAgent.cs (4)
src/ReinforcementLearning/Agents/Dreamer/DreamerAgent.cs (10)
T(155-173)T(175-234)T(236-281)T(283-311)NeuralNetwork(88-103)NeuralNetwork(105-121)InitializeReplayBuffer(123-126)Vector(128-148)Vector(313-325)Vector(341-344)src/ReinforcementLearning/Agents/MADDPG/MADDPGAgent.cs (12)
T(199-223)T(225-260)T(262-304)NeuralNetwork(92-108)NeuralNetwork(110-127)InitializeReplayBuffer(129-132)CopyNetworkWeights(345-358)Vector(137-158)Vector(160-164)Vector(360-381)Vector(383-395)Vector(411-414)src/ReinforcementLearning/Agents/MuZero/MuZeroAgent.cs (9)
T(294-297)T(304-358)NeuralNetwork(87-102)InitializeReplayBuffer(104-107)Vector(109-130)Vector(284-292)Vector(360-372)Vector(405-408)ArgMax(374-389)src/NeuralNetworks/Layers/DenseLayer.cs (1)
SetWeights(564-579)
src/ReinforcementLearning/Agents/Dreamer/DreamerAgent.cs (3)
src/ReinforcementLearning/Agents/MADDPG/MADDPGAgent.cs (10)
T(199-223)T(225-260)T(262-304)InitializeReplayBuffer(129-132)Vector(137-158)Vector(160-164)Vector(360-381)Vector(383-395)Vector(411-414)StoreExperience(193-197)src/ReinforcementLearning/Agents/MuZero/MuZeroAgent.cs (8)
T(294-297)T(304-358)InitializeReplayBuffer(104-107)Vector(109-130)Vector(284-292)Vector(360-372)Vector(405-408)StoreExperience(299-302)src/Helpers/MathHelper.cs (2)
INumericOperations(36-66)MathHelper(19-992)
src/ReinforcementLearning/Agents/Rainbow/RainbowDQNAgent.cs (3)
src/Helpers/MathHelper.cs (1)
INumericOperations(36-66)src/ReinforcementLearning/ReplayBuffers/PrioritizedReplayBuffer.cs (1)
UpdatePriorities(110-118)src/NeuralNetworks/Layers/DenseLayer.cs (1)
SetWeights(564-579)
src/ReinforcementLearning/Agents/MADDPG/MADDPGAgent.cs (4)
src/ReinforcementLearning/Agents/QMIX/QMIXAgent.cs (11)
T(197-281)T(412-423)List(283-318)NeuralNetwork(90-105)NeuralNetwork(107-125)CopyNetworkWeights(380-393)Vector(135-158)Vector(160-164)Vector(320-332)Vector(334-355)Vector(357-378)src/Helpers/MathHelper.cs (2)
INumericOperations(36-66)MathHelper(19-992)src/ReinforcementLearning/Agents/TD3/TD3Agent.cs (1)
UpdateActor(243-274)src/NeuralNetworks/Layers/DenseLayer.cs (1)
SetWeights(564-579)
src/ReinforcementLearning/Agents/DecisionTransformer/DecisionTransformerAgent.cs (1)
src/ReinforcementLearning/Agents/Dreamer/DreamerAgent.cs (11)
T(155-173)T(175-234)T(236-281)T(283-311)NeuralNetwork(88-103)NeuralNetwork(105-121)Vector(128-148)Vector(313-325)Vector(341-344)StoreExperience(150-153)Dictionary(327-334)
🔇 Additional comments (1)
src/ReinforcementLearning/Agents/Dreamer/DreamerAgent.cs (1)
336-339: Empty ResetEpisode implementation is acceptable for now.The current stateless design doesn't require episode-specific resets. If future enhancements add episode-level state (e.g., eligibility traces, episode buffers, or recurrent hidden states), this method will need implementation.
Actionable comments posted: 5
♻️ Duplicate comments (24)
src/ReinforcementLearning/Agents/Rainbow/RainbowDQNAgent.cs (2)
109-109: Fix the value-stream to head wiring (past issue still present).Line 109 and Line 114 still use
previousSizefor the head input, but onceValueStreamLayersrun, the preceding output width isvalueSize. As soon as any value-stream layer is configured, the dense head will receive tensors of the wrong size and throw at the first forward pass. UpdatepreviousSizeafter the value-stream loop before adding the head.Apply this diff:
foreach (var layerSize in _options.ValueStreamLayers) { network.AddLayer(new DenseLayer<T>(valueSize, layerSize)); network.AddLayer(new ActivationLayer<T>(new ReLU<T>())); valueSize = layerSize; } + previousSize = valueSize; + if (_options.UseDistributional) { // Distributional RL: output atoms for each action network.AddLayer(new DenseLayer<T>(previousSize, _options.ActionSize * _options.NumAtoms));Also applies to: 114-114
296-299: Match gradient dimensionality to the distributional head (past issue still present).With
UseDistributional = true, the network output has_options.ActionSize * _options.NumAtomsunits, yet Line 296 still allocates a gradient of length_options.ActionSize._onlineNetwork.Backwardwill fail with a shape mismatch, so distributional training currently crashes.Apply this diff to fix the gradient allocation:
- var gradient = new Vector<T>(_options.ActionSize); - gradient[actionIndex] = tdError; + var gradientSize = _options.UseDistributional + ? _options.ActionSize * _options.NumAtoms + : _options.ActionSize; + var gradient = new Vector<T>(gradientSize); + + if (_options.UseDistributional) + { + int offset = actionIndex * _options.NumAtoms; + for (int atom = 0; atom < _options.NumAtoms; atom++) + { + gradient[offset + atom] = tdError; + } + } + else + { + gradient[actionIndex] = tdError; + }src/ReinforcementLearning/Agents/A3C/A3CAgent.cs (3)
199-205: Stop sharing one environment instance across worker tasks.Task.Run(() => RunWorker(environment, ...))hands the sameenvironmentto every worker, soReset/Stepexecute concurrently without synchronization. Typical envs here (e.g., CartPole) are not thread-safe, leading to corrupted episode state and undefined training. Give each worker its own environment instance (factory/clone) or run workers sequentially as the comment suggests.
258-263: Use the local policy output you just computed. ReturningSelectAction(state, training)locks and samples from_globalPolicyNetwork, so local weights never drive exploration or trajectories. Derive the action directly frompolicyOutput(mirroring the global logic) and return it to keep worker-local learning meaningful.
327-353: Compute real policy/value output gradients before backprop. FillingpolicyGradientwithadvantage * 0.1(andvalueGradientwith the raw TD error) ignores the sampled action, log-probs, and network outputs, so_globalPolicyNetwork.Backwardreceives meaningless signals. The agent won’t learn. Build the output-space gradients from the actual policy/value predictions (e.g., ∇log π(a|s)·advantage and d/dy ½(return−value)²) before callingBackward/UpdateWeights.src/ReinforcementLearning/Agents/WorldModels/WorldModelsAgent.cs (2)
268-278: Fix missing VAE encoder gradients.
_vaeDecoder.Backward(...)returns the latent gradient, but you discard it and call_vaeEncoder.UpdateWeights(...)without ever supplying a backward pass. The encoder (and KL term) therefore receive zero signal, so the latent posterior never learns. Wire the decoder gradient (plus KL derivative) through the reparameterization path before updating the encoder.- _vaeDecoder.Backward(gradient); - _vaeDecoder.UpdateWeights(_options.LearningRate); - - _vaeEncoder.UpdateWeights(_options.LearningRate); + var decoderGrad = _vaeDecoder.Backward(gradient); + _vaeDecoder.UpdateWeights(_options.LearningRate); + + var latentGrad = BuildLatentGradient(decoderGrad, latentMean, latentLogVar); + _vaeEncoder.Backward(latentGrad); + _vaeEncoder.UpdateWeights(_options.LearningRate);Implement
BuildLatentGradientto inject the KL derivatives for μ and logσ².
295-324: Don’t reuse the live hidden state for replay samples.During RNN training every sampled transition is concatenated with the agent’s current
_rnnHiddenState, which almost never matches the timestep the replay entry came from. That feeds stale context into each example, preventing the dynamics model from converging. Persist the hidden state (or sample contiguous sequences and rebuild it by unrolling) alongside each experience and use that stored state instead of_rnnHiddenStatewhen formingrnnInput.src/ReinforcementLearning/Agents/TRPO/TRPOAgent.cs (5)
45-46: Trajectory buffer still missingnextStatefieldThe past review comment correctly identified that
nextStateis dropped from the trajectory buffer, but this has not been fixed. The tuple still only contains(state, action, reward, done)instead of(state, action, reward, nextState, done). This causes incorrect bootstrapping in both Monte Carlo returns (lines 229-240) and GAE advantage computation (lines 246-261).
178-187:nextStateparameter is accepted but never storedLine 180 drops the
nextStateparameter when adding to the trajectory buffer. This was flagged in the previous review but remains unfixed.
229-240: Monte Carlo returns lack proper bootstrappingThe return calculation doesn't bootstrap from the value of
nextStatewhendone=false. Instead, it uses pure Monte Carlo returns across the entire buffer. BecausenextStateis missing from the trajectory buffer (see previous comments), the code cannot compute the correct bootstrapped returns at episode boundaries. This was flagged in the previous review and remains unfixed.
246-261: GAE advantage computation uses wrong value for bootstrappingLine 248 uses
values[i + 1]as the next value estimate, but this is the value of the state at transition i+1, not the value ofnextStatefrom transition i. Because the trajectory buffer doesn't storenextState, the code cannot compute the correct TD error for GAE. This is part of the broader issue flagged in the previous review.
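For reference, a sketch of GAE that bootstraps from each transition's stored next-state value (plain C# over parallel arrays; this assumes the trajectory buffer is extended to carry nextState as requested above):

// delta_t = r_t + gamma * V(s'_t) * (1 - done_t) - V(s_t)
// A_t     = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}
static double[] ComputeGae(
    double[] rewards, double[] values, double[] nextValues, bool[] dones,
    double gamma, double lambda)
{
    int n = rewards.Length;
    var advantages = new double[n];
    double gae = 0.0;

    for (int t = n - 1; t >= 0; t--)
    {
        double mask = dones[t] ? 0.0 : 1.0;
        double delta = rewards[t] + gamma * nextValues[t] * mask - values[t];
        gae = delta + gamma * lambda * mask * gae;
        advantages[t] = gae;
    }
    return advantages;
}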
304-327: Policy gradient computation still ignores recorded actionsThe policy update loop (lines 318-322) creates a gradient where every output neuron receives
advantage * 0.01, completely ignoring which action was actually taken. This was flagged as a critical issue in the previous review and remains unfixed. TRPO requires computing the score function gradient∇_θ log π_θ(a|s)for each stored action, not a uniform gradient across all outputs.src/ReinforcementLearning/Agents/QMIX/QMIXAgent.cs (2)
103-121: Enforce QMIX monotonicity in the mixer.The plain
DenseLayerstack allows negative weights, breaking QMIX's monotonicity constraint. QMIX requires the mixing network to use non-negative weights (typically via hypernetworks with SoftPlus/ReLU outputs) so that decentralized execution remains valid.
243-263: Fix TD gradient flow through mixer and agents.The gradient computation has two issues:
mixingGradient[0] = tdErrormisses the-2factor from the squared TD loss derivative- Agent gradients are divided evenly instead of using the mixer's partial derivatives (Jacobian) with respect to each agent's Q-value
Without the mixer's input gradients, agent updates are mathematically incorrect and the algorithm will not converge properly.
src/ReinforcementLearning/Agents/MADDPG/MADDPGAgent.cs (2)
221-256: Fix target Q computation to use target actors.Line 228 uses the stored
experience.actionwithexperience.nextState, which is incorrect for MADDPG. The TD target should use the next joint action produced by rolling the target actors forward on the next state, not the stored action.Additionally, line 250 sets
gradient[0] = errorbut is missing the-2factor from the derivative of the squared TD error loss.
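A sketch of the corrected target computation (plain C#; the delegate parameters stand in for the target actor and critic networks and are purely illustrative):

using System;
using System.Collections.Generic;

static class MaddpgTargetSketch
{
    // y_i = r_i + gamma * (1 - done) * Q_i_target(x', a_1', ..., a_N'), where each a_j' = mu_j_target(o_j').
    // The next joint action must come from the *target* actors evaluated on the next observations,
    // not from the stored experience.action.
    public static double ComputeTarget(
        double reward, bool done, double gamma,
        Func<double[], double[]>[] targetActors,        // one target policy per agent
        Func<double[], double[], double> targetCritic,  // Q_i_target(joint next state, joint next action)
        double[][] nextObservations, double[] jointNextState)
    {
        var jointNextAction = new List<double>();
        for (int j = 0; j < targetActors.Length; j++)
        {
            jointNextAction.AddRange(targetActors[j](nextObservations[j]));
        }

        double nextQ = targetCritic(jointNextState, jointNextAction.ToArray());
        return reward + (done ? 0.0 : gamma * nextQ);
    }
}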
258-300: Propagate critic gradients instead of fabricating actor gradients.Lines 288-293 create a constant gradient vector by dividing the Q-value evenly across action dimensions. This does not implement the deterministic policy gradient—no gradient from the critic ever reaches the actor parameters. MADDPG requires backpropagating the critic's derivative w.r.t. each agent's action to update the actor.
src/ReinforcementLearning/Agents/Dreamer/DreamerAgent.cs (4)
175-227: Critical: representation encoder never trained.The world-model loop only updates the dynamics, reward, and continue heads;
_representationNetworkis never backpropagated or updated, so its latents remain near-random and the downstream models learn on meaningless features. Dreamer relies on jointly training the encoder with the rest of the RSSM by propagating losses back through it; without that gradient path the agent cannot learn a usable latent space. (emergentmind.com)Please chain the combined world-model gradients (from dynamics/reward/continue losses) back into
_representationNetworkand call itsBackward/UpdateWeights, e.g. accumulate ∂loss/∂latent from each head, propagate it through the encoder, and update once per batch.
208-226: Correct the MSE gradients (missing 2/batch scaling).The gradients fed into the dynamics, reward, and continue networks are just
(predicted - target). The derivative of mean-squared error is2 * (predicted - target) / batchSize; omitting the factor and normalization inflates the effective learning rate and makes updates batch-size dependent. (en.wikipedia.org)Scale each gradient component by
2and divide bybatch.Count(using_numOps) before callingBackward/UpdateWeightsso the reported average loss matches the applied gradients.
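In scalar form, the gradient the comment asks for is simply the following (illustrative sketch):

// Mean-squared-error gradient with respect to one predicted component, averaged over the batch:
// dL/d(predicted) = 2 * (predicted - target) / batchSize
static double MseGradient(double predicted, double target, int batchSize)
{
    return 2.0 * (predicted - target) / batchSize;
}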
263-271: Actor gradient must backprop through imagined value, not divide by length.Computing the actor gradient as
valueDiff / action.Lengthignores how the value changes with each action dimension, so the policy receives no directional signal and will not learn. Dreamer updates the policy by backpropagating analytic value gradients through the imagined trajectory (∂V/∂a chained with ∂a/∂θ). (emergentmind.com)Please replace this heuristic with a proper gradient from the critic/value estimate—e.g., differentiate the imagined return w.r.t. the action via the value network and dynamics, then feed that into
_actorNetwork.Backward.
279-306: Add a discount factor to imagined returns.
ImagineTrajectorycurrently sums raw rewards, so long rollouts can dominate and the value function no longer represents the standard discounted return used elsewhere in the agent. Reinforcement learning returns are defined with a discount γ to balance near- and long-term rewards and ensure convergence. (damek.github.io)Introduce a
_options.DiscountFactor(e.g., 0.99), maintain a running multiplier inside the loop, accumulateimaginedReturn += discount * reward, and updatediscount *= _options.DiscountFactoreach step.src/ReinforcementLearning/Agents/MuZero/MCTSNode.cs (1)
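The suggested accumulation, written out as a small sketch (plain C#; discountFactor stands in for the proposed option):

// Discounted return over an imagined rollout: G = sum_t gamma^t * r_t
static double DiscountedImaginedReturn(double[] imaginedRewards, double discountFactor)
{
    double imaginedReturn = 0.0;
    double discount = 1.0;
    foreach (double reward in imaginedRewards)
    {
        imaginedReturn += discount * reward;
        discount *= discountFactor;
    }
    return imaginedReturn;
}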
13-16: Add reward bookkeeping to nodes.Line 15 only tracks visit counts and Q-values, so the immediate reward from the dynamics model has nowhere to live and the backup ends up propagating γ v with r = 0—exactly the regression that was flagged in the earlier review. Please add per-action reward storage so the agent can persist and consume the reward during backup.
public Dictionary<int, MCTSNode<T>> Children { get; set; } = new(); public Dictionary<int, int> VisitCounts { get; set; } = new(); public Dictionary<int, T> QValues { get; set; } = new(); + public Dictionary<int, T> Rewards { get; set; } = new();src/ReinforcementLearning/Agents/MuZero/MuZeroAgent.cs (3)
166-173: Don’t push the same edge twice.Line 156 already pushes (node, action) onto
path; when expansion happens you callSelectActionPUCTagain on Line 169 and push the identical tuple on Line 172. Backup then bumpsVisitCountsandTotalVisitstwice per simulation on that edge, blowing up the statistics (same problem that was reported earlier). Guard the expansion block so it only appends when the edge was not already recorded.if (path.Count < _options.UnrollSteps) { int action = SelectActionPUCT(currentNode); - var child = ExpandNode(currentNode, action); - currentNode.Children[action] = child; - path.Add((currentNode, action)); - currentNode = child; + if (path.Count == 0 || path[path.Count - 1].node != currentNode || path[path.Count - 1].action != action) + { + path.Add((currentNode, action)); + } + + var child = ExpandNode(currentNode, action); + currentNode.Children[action] = child; + currentNode = child; }
190-204: Fix the backup math and include the reward.On Line 190 you increment
VisitCountsbefore using it to weight the running mean, so the first visit setsQto half the backed-up value; and because no reward is read, the backup uses γ v only. That combination was already called out as corrupting the tree statistics. Use the previous visit count in the averaging formula and fold in the stored reward before you discount.- pathNode.VisitCounts[pathAction]++; - pathNode.TotalVisits++; - - // Update Q-value: Q = (Q * n + v) / (n + 1) - var oldQ = pathNode.QValues[pathAction]; - var visitCount = NumOps.FromDouble(pathNode.VisitCounts[pathAction]); - var newQ = NumOps.Divide( - NumOps.Add(NumOps.Multiply(oldQ, visitCount), value), - NumOps.Add(visitCount, NumOps.One)); - - pathNode.QValues[pathAction] = newQ; - - // Discount value for parent - value = NumOps.Multiply(_options.DiscountFactor, value); + int previousVisits = pathNode.VisitCounts[pathAction]; + var oldQ = pathNode.QValues[pathAction]; + + var reward = pathNode.Rewards.TryGetValue(pathAction, out var storedReward) + ? storedReward + : NumOps.Zero; + + var backedValue = NumOps.Add(reward, NumOps.Multiply(_options.DiscountFactor, value)); + + var numerator = NumOps.Add( + NumOps.Multiply(oldQ, NumOps.FromDouble(previousVisits)), + backedValue); + pathNode.QValues[pathAction] = NumOps.Divide( + numerator, + NumOps.FromDouble(previousVisits + 1)); + + pathNode.VisitCounts[pathAction] = previousVisits + 1; + pathNode.TotalVisits++; + + value = backedValue;
250-266: Persist the dynamics reward.Line 255 discards
dynamicsOutput[_options.LatentStateSize], so the parent never records the immediate reward and backup has nothing to add. Please stash that value into the parent’s reward dictionary before returning the child node.// Extract next hidden state and reward var nextHiddenState = new Vector<T>(_options.LatentStateSize); for (int i = 0; i < _options.LatentStateSize; i++) { nextHiddenState[i] = dynamicsOutput[i]; } + + var reward = dynamicsOutput[_options.LatentStateSize]; + parent.Rewards[action] = reward; // Get value from prediction network var prediction = _predictionNetwork.Forward(nextHiddenState); var value = ExtractValue(prediction);
🧹 Nitpick comments (6)
src/ReinforcementLearning/Agents/Rainbow/RainbowDQNAgent.cs (3)
84-92: Remove redundant conditional branch.

Both branches of the UseNoisyNetworks check add identical ReLU activation layers, making the conditional meaningless. Either implement actual noisy layer functionality or remove the check. Apply this diff to simplify:

-if (_options.UseNoisyNetworks)
-{
-    // Add noise to weights for exploration (simplified)
-    network.AddLayer(new ActivationLayer<T>(new ReLU<T>()));
-}
-else
-{
-    network.AddLayer(new ActivationLayer<T>(new ReLU<T>()));
-}
+network.AddLayer(new ActivationLayer<T>(new ReLU<T>()));
271-276: Simplify n-step discount computation.

The loop computing nStepDiscount can be replaced with a single power operation for better clarity and efficiency. Apply this diff:

-var nStepDiscount = NumOps.One;
-for (int n = 0; n < _options.NSteps; n++)
-{
-    nStepDiscount = NumOps.Multiply(nStepDiscount, _options.DiscountFactor);
-}
+var nStepDiscount = NumOps.Power(_options.DiscountFactor, NumOps.FromDouble(_options.NSteps));

Note: This assumes NumOps.Power is available. If not, consider adding it or keeping the loop with a comment explaining why.
316-342: Consider refactoring to eliminate code duplication.
ComputeQValuesFromNetworkduplicates the exact logic fromComputeQValues(lines 148-177). Consider refactoringComputeQValuesto callComputeQValuesFromNetwork(_onlineNetwork, state)to adhere to DRY principles.Apply this diff to eliminate duplication:
private Vector<T> ComputeQValues(Vector<T> state) { - var output = _onlineNetwork.Forward(state); - - if (_options.UseDistributional) - { - // Distributional RL: convert distribution to Q-values - var qValues = new Vector<T>(_options.ActionSize); - double deltaZ = (_options.VMax - _options.VMin) / (_options.NumAtoms - 1); - - for (int action = 0; action < _options.ActionSize; action++) - { - T qValue = NumOps.Zero; - for (int atom = 0; atom < _options.NumAtoms; atom++) - { - int idx = action * _options.NumAtoms + atom; - double z = _options.VMin + atom * deltaZ; - var prob = output[idx]; - qValue = NumOps.Add(qValue, NumOps.Multiply(prob, NumOps.FromDouble(z))); - } - qValues[action] = qValue; - } - - return qValues; - } - else - { - return output; - } + return ComputeQValuesFromNetwork(_onlineNetwork, state); }src/ReinforcementLearning/Agents/TRPO/TRPOAgent.cs (1)
278-294: Value function update uses simplified gradient computationThe value network gradient is computed as just the error (line 288), rather than using a proper loss function like MSE. For MSE loss
L = (y - ŷ)² / 2, the gradient with respect to the prediction should be-(y - ŷ), but line 288 uses positiveerror = returns[i] - predictedValue.While this might work if the network's backward pass handles it appropriately, a more standard approach would compute MSE loss and backpropagate properly:
- var error = NumOps.Subtract(returns[i], predictedValue); - - var gradient = new Vector<T>(1); - gradient[0] = error; - - _valueNetwork.Backward(gradient); + // MSE loss gradient: -(target - prediction) + var error = NumOps.Subtract(predictedValue, returns[i]); + + var gradient = new Vector<T>(1); + gradient[0] = error; + + _valueNetwork.Backward(gradient);src/ReinforcementLearning/Agents/DecisionTransformer/SequenceContext.cs (2)
12-14: Consider using init setters for better encapsulation.

The public setters allow external code to replace the entire list instances, which could lead to unexpected behavior. Since these collections are meant to be modified through their list methods (Add, RemoveAt), consider using init setters instead:

-public List<Vector<T>> States { get; set; } = new();
-public List<Vector<T>> Actions { get; set; } = new();
-public List<T> ReturnsToGo { get; set; } = new();
+public List<Vector<T>> States { get; init; } = new();
+public List<Vector<T>> Actions { get; init; } = new();
+public List<T> ReturnsToGo { get; init; } = new();

This prevents accidental replacement while still allowing collection initialization and modification via list methods.
5-16: Consider documenting the collection size invariant.

The implementation allows Actions.Count to be less than States.Count, which is correct for the Decision Transformer pattern where a state is added before the action is predicted. However, this invariant isn't documented. Consider adding a remark to clarify the expected relationship:

 /// <summary>
 /// Context window for sequence modeling in Decision Transformer.
 /// Maintains recent states, actions, and returns-to-go for transformer input.
 /// </summary>
+/// <remarks>
+/// The Actions list may contain fewer elements than States and ReturnsToGo
+/// because states are added before actions are predicted and added.
+/// </remarks>
 /// <typeparam name="T">The numeric type used for calculations.</typeparam>

This helps future maintainers understand why Length is based on States.Count rather than requiring all collections to be synchronized.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (12)
src/ReinforcementLearning/Agents/A3C/A3CAgent.cs(1 hunks)src/ReinforcementLearning/Agents/A3C/WorkerNetworks.cs(1 hunks)src/ReinforcementLearning/Agents/DecisionTransformer/DecisionTransformerAgent.cs(1 hunks)src/ReinforcementLearning/Agents/DecisionTransformer/SequenceContext.cs(1 hunks)src/ReinforcementLearning/Agents/Dreamer/DreamerAgent.cs(1 hunks)src/ReinforcementLearning/Agents/MADDPG/MADDPGAgent.cs(1 hunks)src/ReinforcementLearning/Agents/MuZero/MCTSNode.cs(1 hunks)src/ReinforcementLearning/Agents/MuZero/MuZeroAgent.cs(1 hunks)src/ReinforcementLearning/Agents/QMIX/QMIXAgent.cs(1 hunks)src/ReinforcementLearning/Agents/Rainbow/RainbowDQNAgent.cs(1 hunks)src/ReinforcementLearning/Agents/TRPO/TRPOAgent.cs(1 hunks)src/ReinforcementLearning/Agents/WorldModels/WorldModelsAgent.cs(1 hunks)
🧰 Additional context used
🧬 Code graph analysis (12)
src/ReinforcementLearning/Agents/Rainbow/RainbowDQNAgent.cs (2)
src/ReinforcementLearning/ReplayBuffers/PrioritizedReplayBuffer.cs (1)
UpdatePriorities(110-118)src/NeuralNetworks/Layers/DenseLayer.cs (1)
SetWeights(564-579)
src/ReinforcementLearning/Agents/A3C/WorkerNetworks.cs (1)
src/ReinforcementLearning/Agents/A3C/A3CAgent.cs (8)
T(376-380)NeuralNetwork(59-84)NeuralNetwork(86-101)List(266-293)List(295-318)Vector(103-187)Vector(258-264)Vector(395-398)
src/ReinforcementLearning/Agents/MuZero/MuZeroAgent.cs (2)
src/ReinforcementLearning/Agents/Dreamer/DreamerAgent.cs (13)
T(151-169)T(171-230)T(232-277)T(279-307)NeuralNetwork(84-99)NeuralNetwork(101-117)InitializeNetworks(65-82)InitializeReplayBuffer(119-122)Vector(124-144)Vector(309-321)Vector(337-340)StoreExperience(146-149)Dictionary(323-330)src/ReinforcementLearning/Agents/MuZero/MCTSNode.cs (1)
MCTSNode(10-18)
src/ReinforcementLearning/Agents/Dreamer/DreamerAgent.cs (2)
src/ReinforcementLearning/Agents/MADDPG/MADDPGAgent.cs (11)
T(195-219)T(221-256)T(258-300)NeuralNetwork(88-104)NeuralNetwork(106-123)InitializeReplayBuffer(125-128)Vector(133-154)Vector(156-160)Vector(356-377)Vector(379-391)Vector(407-410)src/Helpers/MathHelper.cs (1)
MathHelper(19-992)
src/ReinforcementLearning/Agents/A3C/A3CAgent.cs (7)
src/ReinforcementLearning/Agents/DecisionTransformer/DecisionTransformerAgent.cs (8)
T(182-227)Vector(117-120)Vector(125-156)Vector(158-174)Vector(256-259)Task(261-264)Task(266-270)StoreExperience(176-180)src/ReinforcementLearning/Agents/Dreamer/DreamerAgent.cs (12)
T(151-169)T(171-230)T(232-277)T(279-307)NeuralNetwork(84-99)NeuralNetwork(101-117)Vector(124-144)Vector(309-321)Vector(337-340)Task(342-345)Task(347-351)StoreExperience(146-149)src/ReinforcementLearning/Agents/MADDPG/MADDPGAgent.cs (14)
T(195-219)T(221-256)T(258-300)NeuralNetwork(88-104)NeuralNetwork(106-123)Vector(133-154)Vector(156-160)Vector(356-377)Vector(379-391)Vector(407-410)Task(412-415)Task(417-421)CopyNetworkWeights(341-354)StoreExperience(189-193)src/ReinforcementLearning/Agents/QMIX/QMIXAgent.cs (8)
T(193-277)T(408-419)NeuralNetwork(86-101)NeuralNetwork(103-121)Task(441-444)Task(446-450)CopyNetworkWeights(376-389)StoreExperience(186-191)src/Helpers/MathHelper.cs (1)
MathHelper(19-992)src/ReinforcementLearning/Agents/REINFORCE/REINFORCEAgent.cs (1)
ComputeReturns(240-279)src/NeuralNetworks/Layers/DenseLayer.cs (1)
SetWeights(564-579)
src/ReinforcementLearning/Agents/QMIX/QMIXAgent.cs (2)
src/ReinforcementLearning/Agents/A3C/A3CAgent.cs (4)
T(376-380)NeuralNetwork(59-84)NeuralNetwork(86-101)CopyNetworkWeights(356-369)src/NeuralNetworks/Layers/DenseLayer.cs (1)
SetWeights(564-579)
src/ReinforcementLearning/Agents/WorldModels/WorldModelsAgent.cs (2)
src/ReinforcementLearning/Agents/Dreamer/DreamerAgent.cs (12)
T(151-169)T(171-230)T(232-277)T(279-307)NeuralNetwork(84-99)NeuralNetwork(101-117)Vector(124-144)Vector(309-321)Vector(337-340)InitializeNetworks(65-82)InitializeReplayBuffer(119-122)StoreExperience(146-149)src/Helpers/MathHelper.cs (1)
MathHelper(19-992)
src/ReinforcementLearning/Agents/MuZero/MCTSNode.cs (1)
src/ReinforcementLearning/Agents/MuZero/MuZeroAgent.cs (4)
MCTSNode(241-267)T(279-282)T(289-343)Vector(94-115)
src/ReinforcementLearning/Agents/TRPO/TRPOAgent.cs (4)
src/ReinforcementLearning/Agents/A3C/A3CAgent.cs (4)
T(376-380)NeuralNetwork(59-84)NeuralNetwork(86-101)CopyNetworkWeights(356-369)src/Helpers/MathHelper.cs (1)
MathHelper(19-992)src/Helpers/StatisticsHelper.cs (1)
StatisticsHelper(17-6685)src/NeuralNetworks/Layers/DenseLayer.cs (1)
SetWeights(564-579)
src/ReinforcementLearning/Agents/DecisionTransformer/SequenceContext.cs (1)
src/ReinforcementLearning/Agents/DecisionTransformer/DecisionTransformerAgent.cs (5)
T(182-227)Vector(117-120)Vector(125-156)Vector(158-174)Vector(256-259)
src/ReinforcementLearning/Agents/MADDPG/MADDPGAgent.cs (2)
src/Helpers/MathHelper.cs (1)
MathHelper(19-992)src/NeuralNetworks/Layers/DenseLayer.cs (1)
SetWeights(564-579)
src/ReinforcementLearning/Agents/DecisionTransformer/DecisionTransformerAgent.cs (1)
src/ReinforcementLearning/Agents/DecisionTransformer/SequenceContext.cs (1)
SequenceContext(10-16)
🔇 Additional comments (17)
src/ReinforcementLearning/Agents/Rainbow/RainbowDQNAgent.cs (4)
54-72: Constructor and initialization look correct.The constructor properly initializes all required fields, creates both online and target networks, and sets up the prioritized replay buffer.
125-177: Action selection and Q-value computation look correct.The epsilon-greedy strategy properly accounts for noisy networks, and the distributional Q-value computation correctly aggregates atom probabilities with their support values.
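For readers new to distributional RL, the aggregation being approved here reduces to the following (a standalone sketch with double; numAtoms, vMin, and vMax mirror the options described earlier and are assumptions about the configuration, not the agent's exact code):

// Q(s, a) = sum_i p_i(s, a) * z_i, where z_i = vMin + i * (vMax - vMin) / (numAtoms - 1)
static double ExpectedQFromAtoms(double[] atomProbabilities, double vMin, double vMax)
{
    int numAtoms = atomProbabilities.Length;
    double deltaZ = (vMax - vMin) / (numAtoms - 1);

    double q = 0.0;
    for (int i = 0; i < numAtoms; i++)
    {
        q += atomProbabilities[i] * (vMin + i * deltaZ);
    }
    return q;
}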
179-234: N-step experience handling is correctly implemented.The n-step buffer accumulation, return computation, and buffer management logic are all correct. Epsilon decay and beta scheduling for importance sampling are properly handled.
344-406: Helper and utility methods are correctly implemented.Weight copying, ArgMax, metrics reporting, episode reset, and prediction methods all have correct logic and proper implementations.
src/ReinforcementLearning/Agents/QMIX/QMIXAgent.cs (5)
165-184: LGTM!The multi-agent experience storage correctly concatenates per-agent observations with global state and handles epsilon decay appropriately.
279-314: LGTM!The joint state decomposition correctly extracts per-agent states, global state, and per-agent actions using appropriate offset calculations.
316-374: LGTM!The concatenation helper methods correctly assemble vectors from components. The assumption that all agent vectors have the same length (line 332) is appropriate for this multi-agent setting.
376-419: LGTM!The utility methods for network weight copying and vector operations are implemented correctly.
421-450: LGTM!The lifecycle and prediction methods correctly implement the base class contract and provide appropriate metrics.
src/ReinforcementLearning/Agents/MADDPG/MADDPGAgent.cs (8)
39-59: LGTM!The class structure and initialization properly set up per-agent actor/critic networks with their target counterparts.
61-86: LGTM!Network initialization correctly creates and initializes per-agent actor/critic pairs with synchronized target networks.
88-123: LGTM!Actor and critic network architectures are appropriate for MADDPG: actors produce bounded continuous actions via Tanh, and critics are centralized to observe all agents' states and actions.
133-154: LGTM!Action selection correctly adds Gaussian exploration noise during training and clamps actions to valid bounds.
165-187: LGTM with note on design choice.The multi-agent experience storage uses average reward across agents. This is acceptable for cooperative settings but loses per-agent reward information. For competitive or mixed settings, consider storing per-agent rewards.
302-339: LGTM!Soft target network updates correctly implement Polyak averaging (τ * source + (1-τ) * target) for smooth target network evolution.
341-391: LGTM!Helper methods for network weight management and vector concatenation are implemented correctly.
393-421: LGTM!Lifecycle and prediction methods correctly implement the base class contract.
Actionable comments posted: 5
♻️ Duplicate comments (3)
src/ReinforcementLearning/Agents/A3C/A3CAgent.cs (3)
221-234: Critical concurrency issue with shared environment (already flagged).As noted in previous reviews, all workers share the same
environmentinstance. Concurrent calls toReset()andStep()will corrupt episode state unless the environment is explicitly thread-safe (which typical RL environments are not). Each worker needs its own environment instance, or workers should run sequentially.
286-292: Local policy network output is ignored (already flagged).As noted in previous reviews, line 288 computes
policyOutputfrom the local network but line 291 discards it and calls the globalSelectAction, which locks and forwards through_globalPolicyNetwork. Workers should sample actions from their local networks to ensure trajectories reflect local parameters.
348-382: Incorrect gradient computations (already flagged).As noted in previous reviews, the policy gradient (lines 363-365) is set to a constant
advantage * 0.1for all output dimensions, ignoring log-probabilities and action selection. Similarly, the value gradient (lines 375-377) is set directly to the TD error without computing the derivative of the loss. These are not proper gradients and will prevent meaningful learning.
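For concreteness, the output-space gradients the comment calls for look like this for a softmax policy head and a scalar value head (a sketch in plain C#, not the agent's code; names are illustrative):

// Per-transition gradients fed into Backward():
//   policy head (softmax over logits): d(-A * log pi_a)/d logit_j = A * (pi_j - 1[j == a])
//   value head (0.5 * (R - V)^2 loss): dL/dV = V - R
static (double[] PolicyGradient, double ValueGradient) ActorCriticOutputGradients(
    double[] actionProbabilities, int takenAction, double advantage,
    double predictedValue, double targetReturn)
{
    var policyGradient = new double[actionProbabilities.Length];
    for (int j = 0; j < policyGradient.Length; j++)
    {
        double indicator = j == takenAction ? 1.0 : 0.0;
        policyGradient[j] = advantage * (actionProbabilities[j] - indicator);
    }

    double valueGradient = predictedValue - targetReturn;
    return (policyGradient, valueGradient);
}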
🧹 Nitpick comments (2)
src/Models/Options/RainbowDQNOptions.cs (1)
23-24: Consider validating StateSize and ActionSize.Since these parameters are essential for agent initialization, consider adding validation to ensure they are positive values. This provides earlier error detection and a clearer contract.
For example, you could add validation in a constructor or via required properties with validation, though init-only properties make this more challenging. Alternatively, the consuming
RainbowDQNAgentcould validate these on construction.src/ReinforcementLearning/Agents/A3C/A3CAgent.cs (1)
294-321: Consider optimizing return computation for larger trajectories.

The use of Insert(0, runningReturn) in line 317 results in O(n²) time complexity for trajectory length n. For the default TMax=5, this is negligible, but if users configure larger trajectory lengths, this could become a bottleneck. Apply this diff to optimize:

 T runningReturn = nextValue;
 for (int i = trajectory.Count - 1; i >= 0; i--)
 {
     var exp = trajectory[i];
     if (exp.done)
     {
         runningReturn = exp.reward;
     }
     else
     {
         runningReturn = NumOps.Add(exp.reward, NumOps.Multiply(_options.DiscountFactor, runningReturn));
     }
-    returns.Insert(0, runningReturn);
+    returns.Add(runningReturn);
 }
+returns.Reverse();
 return returns;
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
src/Models/Options/A3COptions.cs(1 hunks)src/Models/Options/RainbowDQNOptions.cs(1 hunks)src/Models/Options/TRPOOptions.cs(1 hunks)src/ReinforcementLearning/Agents/A3C/A3CAgent.cs(1 hunks)
🧰 Additional context used
🧬 Code graph analysis (4)
src/ReinforcementLearning/Agents/A3C/A3CAgent.cs (6)
src/ReinforcementLearning/Agents/MADDPG/MADDPGAgent.cs (11)
T(195-219)T(221-256)T(258-300)Vector(133-154)Vector(156-160)Vector(356-377)Vector(379-391)Vector(407-410)NeuralNetwork(88-104)NeuralNetwork(106-123)CopyNetworkWeights(341-354)src/ReinforcementLearning/Agents/QMIX/QMIXAgent.cs (6)
T(193-277)T(408-419)List(279-314)NeuralNetwork(86-101)NeuralNetwork(103-121)CopyNetworkWeights(376-389)src/ReinforcementLearning/Agents/TRPO/TRPOAgent.cs (6)
T(189-208)T(330-350)List(210-276)NeuralNetwork(66-91)NeuralNetwork(93-108)CopyNetworkWeights(352-365)src/Models/Options/A3COptions.cs (2)
A3COptions(32-63)A3COptions(55-62)src/Helpers/MathHelper.cs (1)
MathHelper(19-992)src/Helpers/StatisticsHelper.cs (1)
StatisticsHelper(17-6685)
src/Models/Options/RainbowDQNOptions.cs (1)
src/ReinforcementLearning/Agents/ReinforcementLearningAgentBase.cs (1)
ReinforcementLearningOptions(316-382)
src/Models/Options/TRPOOptions.cs (2)
src/ReinforcementLearning/Agents/ReinforcementLearningAgentBase.cs (1)
ReinforcementLearningOptions(316-382)src/Helpers/NeuralNetworkHelper.cs (1)
ILossFunction(49-76)
src/Models/Options/A3COptions.cs (2)
src/ReinforcementLearning/Agents/A3C/A3CAgent.cs (6)
T(395-399)List(294-321)List(323-346)Vector(131-215)Vector(286-292)Vector(414-417)src/ReinforcementLearning/Agents/ReinforcementLearningAgentBase.cs (1)
ReinforcementLearningOptions(316-382)
🔇 Additional comments (16)
src/Models/Options/TRPOOptions.cs (10)
1-4: LGTM! The imports are appropriate for this TRPO options class, covering interfaces, linear algebra types, loss functions, and RL agent infrastructure.
8-33: LGTM! Excellent documentation that balances technical accuracy with beginner-friendly explanations. The trust region concept and its connection to monotonic improvement are clearly explained.
34-34: LGTM! The class properly inherits from `ReinforcementLearningOptions<T>`, enabling it to leverage common RL configuration properties while adding TRPO-specific parameters.
38-38: LGTM! Reasonable default of `false` for discrete action spaces, which are common in RL environments.
39-40: LGTM! These properties are properly initialized in the constructor with standard RL values (0.001 for learning rate and 0.95 for GAE lambda).
42-48: LGTM! TRPO-specific hyperparameters have sensible defaults aligned with the original TRPO paper (Schulman et al., 2015). The trust region constraint (MaxKL = 0.01) and optimization parameters (conjugate gradient iterations, line search steps) are appropriate starting values.
50-51: LGTM! Training control defaults are appropriate: 2048 steps per update is typical for on-policy methods like TRPO, and 5 value function iterations per policy update helps stabilize training.
56-59: LGTM! Good design: nullable optimizer with clear documentation that Adam will be used by default. This provides flexibility while maintaining sensible defaults.
61-67: Constructor initializes core hyperparameters but leaves StateSize/ActionSize uninitialized. The constructor properly uses `NumericOperations<T>.Instance` for type-safe initialization of generic numeric properties. However, the critical `StateSize` and `ActionSize` properties remain uninitialized (see previous comment). The initialization pattern for generic numeric types is correct, but this approval is contingent on addressing the uninitialized required properties.
53-54: LGTM. The default network architecture (two hidden layers of 64 units) is reasonable for many RL problems. The collection expression syntax `[64, 64]` is fully supported, as the project targets C# `latest`.
src/Models/Options/A3COptions.cs (1)
32-62: LGTM! Well-structured configuration class. The A3COptions class is well-designed with appropriate defaults and clear documentation. The use of init-only setters allows constructor defaults while still permitting override via object initializers, which is a good pattern for configuration classes.
src/ReinforcementLearning/Agents/A3C/A3CAgent.cs (5)
107-129: LGTM! Value network correctly configured. The value network architecture is appropriate with hidden layers, a single output unit, and MSE loss function for value function approximation.
131-215: LGTM! Action selection logic is sound. The action selection properly handles both continuous (Gaussian policy) and discrete (categorical) action spaces with appropriate sampling strategies for training vs. inference modes. The log_std clamping prevents numerical instability.
323-346: LGTM! Advantage computation is correct. The advantage calculation (returns minus values) followed by normalization is standard practice in A3C and correctly implemented. The zero-std check prevents division errors.
384-388: LGTM! Clean weight synchronization. Using `GetFlattenedParameters` and `UpdateParameters` is a cleaner approach than manual layer-by-layer copying seen in other agents.
390-427: LGTM! Override methods correctly implemented. All override methods are appropriately implemented for A3C's on-policy, replay-free design. The delegation patterns and no-ops are correct for this algorithm.
Actionable comments posted: 2
♻️ Duplicate comments (10)
src/Models/Options/DecisionTransformerOptions.cs (2)
36-37: Add validation for required parameters. `StateSize` and `ActionSize` have no default values or validation, allowing them to remain at 0. This will cause runtime errors when instantiating a Decision Transformer agent. This issue was previously flagged but remains unresolved. Consider adding a parameterized constructor with validation:

```diff
+public DecisionTransformerOptions(int stateSize, int actionSize)
+{
+    if (stateSize <= 0)
+        throw new ArgumentException("StateSize must be greater than 0", nameof(stateSize));
+    if (actionSize <= 0)
+        throw new ArgumentException("ActionSize must be greater than 0", nameof(actionSize));
+
+    StateSize = stateSize;
+    ActionSize = actionSize;
+}
```
47-47: Reduce BufferSize default to avoid excessive memory usage. The default `BufferSize = 1,000,000` can consume significant memory (potentially ~1GB with typical state/action dimensions). Since Decision Transformer is primarily offline, such a large buffer may be unnecessary. This issue was previously flagged but remains unresolved. Consider reducing the default and adding documentation:

```diff
-public int BufferSize { get; init; } = 1000000;
+/// <summary>
+/// Size of the experience replay buffer. Large values increase memory usage.
+/// For typical use (StateSize=100, ActionSize=4), each experience ~1KB, so 100K entries ≈ 100MB.
+/// </summary>
+public int BufferSize { get; init; } = 100000;
```
src/Models/Options/MADDPGOptions.cs (3)
3-3: Remove unused using directive. The `AiDotNet.LossFunctions` namespace is not referenced in this file. Apply this diff:

```diff
-using AiDotNet.LossFunctions;
-
 namespace AiDotNet.Models.Options;
```
36-38: Add validation for required dimension properties. The properties `NumAgents`, `StateSize`, and `ActionSize` are essential for MADDPG configuration but lack validation. Users could initialize with zero or negative values, causing runtime errors. The previous review suggested either constructor parameters with validation (preferred) or property setters with validation. Since these use `init` accessors, consider making them constructor parameters:

```diff
-    public MADDPGOptions()
+    public MADDPGOptions(int numAgents, int stateSize, int actionSize)
     {
+        if (numAgents <= 0) throw new ArgumentOutOfRangeException(nameof(numAgents));
+        if (stateSize <= 0) throw new ArgumentOutOfRangeException(nameof(stateSize));
+        if (actionSize <= 0) throw new ArgumentOutOfRangeException(nameof(actionSize));
+
+        NumAgents = numAgents;
+        StateSize = stateSize;
+        ActionSize = actionSize;
+
         var numOps = NumericOperations<T>.Instance;
         ActorLearningRate = numOps.FromDouble(0.0001);
         CriticLearningRate = numOps.FromDouble(0.001);
         TargetUpdateTau = numOps.FromDouble(0.001);
     }
```
46-47: Critical: Collection expression syntax incompatible with .NET Framework 4.6. The collection expression syntax `[128, 128]` is a C# 12 feature that requires a later language version than .NET Framework 4.6 supports (which is limited to C# 7.3). This will cause compilation errors when targeting net46. Apply this diff to use compatible syntax:

```diff
-    public List<int> ActorHiddenLayers { get; init; } = [128, 128];
-    public List<int> CriticHiddenLayers { get; init; } = [128, 128];
+    public List<int> ActorHiddenLayers { get; init; } = new List<int> { 128, 128 };
+    public List<int> CriticHiddenLayers { get; init; } = new List<int> { 128, 128 };
```
src/Models/Options/WorldModelsOptions.cs (2)
41-44: C# 12 collection expression incompatible with net462 target. Line 43 uses the C# 12 collection expression syntax `[32, 64, 128, 256]`, which will fail compilation on net462. This issue was already flagged in previous reviews.
52-53: C# 12 collection expression incompatible with net462 target. Line 53 uses the C# 12 collection expression syntax `[32]`, which will fail compilation on net462. This issue was already flagged in previous reviews.
src/Models/Options/DreamerOptions.cs (3)
46-47: Training parameters need validation. `BatchLength` and `ImaginationHorizon` should be validated to ensure they're positive. Include these in the `Validate()` method:

```csharp
if (BatchLength <= 0)
    throw new ArgumentOutOfRangeException(nameof(BatchLength), "BatchLength must be greater than 0");
if (ImaginationHorizon <= 0)
    throw new ArgumentOutOfRangeException(nameof(ImaginationHorizon), "ImaginationHorizon must be greater than 0");
```
36-37: Critical: ObservationSize and ActionSize require validation. These properties default to `0`, which is invalid for any RL agent configuration. While `init` accessors provide immutability after construction, they don't enforce non-zero values. Based on codebase patterns (e.g., `MixtureOfExpertsOptions.Validate()`), add validation to ensure these are set to positive values. Consider adding a `Validate()` method:

```diff
+    /// <summary>
+    /// Validates the configuration options.
+    /// </summary>
+    public void Validate()
+    {
+        if (ObservationSize <= 0)
+            throw new ArgumentOutOfRangeException(nameof(ObservationSize),
+                "ObservationSize must be greater than 0");
+        if (ActionSize <= 0)
+            throw new ArgumentOutOfRangeException(nameof(ActionSize),
+                "ActionSize must be greater than 0");
+        // Add other validations as needed
+    }
```
Based on learnings (from past review comments indicating this pattern exists in the codebase).
3-3: Remove unused using directive. The `using AiDotNet.LossFunctions;` directive is not used anywhere in this file. No types from the `LossFunctions` namespace are referenced. Apply this diff to remove the unused directive:

```diff
 using AiDotNet.Interfaces;
 using AiDotNet.LinearAlgebra;
-using AiDotNet.LossFunctions;
 using AiDotNet.ReinforcementLearning.Agents;
```
🧹 Nitpick comments (4)
src/Models/Options/MADDPGOptions.cs (1)
44-44: Consider adding validation for ExplorationNoise. The `ExplorationNoise` property could be set to negative values, which may not be meaningful for exploration strategies. If validation is desired, consider using a backing field with a property getter/init accessor that validates the range:

```csharp
private double _explorationNoise = 0.1;
public double ExplorationNoise
{
    get => _explorationNoise;
    init => _explorationNoise = value >= 0
        ? value
        : throw new ArgumentOutOfRangeException(nameof(ExplorationNoise), "Must be non-negative");
}
```
src/Models/Options/WorldModelsOptions.cs (1)
36-39: Consider documenting that ActionSize is required. The `ActionSize` property has no default value, making it effectively required. Consider adding XML documentation to clarify that users must set this property based on their environment's action space. Example:

```diff
+    /// <summary>
+    /// Size of the action space. Must be set based on the environment.
+    /// </summary>
     public int ActionSize { get; init; }
```
src/Models/Options/DreamerOptions.cs (2)
40-43: Consider adding validation for architecture sizes. While these properties have reasonable defaults, they could still be set to invalid values (≤0) during initialization. Consider including validation checks for these parameters in a `Validate()` method to maintain configuration integrity. Example validation:

```csharp
if (LatentSize <= 0 || DeterministicSize <= 0 || StochasticSize <= 0 || HiddenSize <= 0)
    throw new ArgumentOutOfRangeException("All architecture sizes must be greater than 0");
```
50-52: Consider validating loss scales. Loss scales (`KLScale`, `RewardScale`, `ContinueScale`) could be set to negative values. Depending on the algorithm's requirements, you may want to enforce non-negative constraints or document the valid range. Example:

```csharp
if (KLScale < 0 || RewardScale < 0 || ContinueScale < 0)
    throw new ArgumentOutOfRangeException("Loss scales must be non-negative");
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
- src/Models/Options/DecisionTransformerOptions.cs (1 hunks)
- src/Models/Options/DreamerOptions.cs (1 hunks)
- src/Models/Options/MADDPGOptions.cs (1 hunks)
- src/Models/Options/MuZeroOptions.cs (1 hunks)
- src/Models/Options/QMIXOptions.cs (1 hunks)
- src/Models/Options/WorldModelsOptions.cs (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- src/Models/Options/MuZeroOptions.cs
🧰 Additional context used
🧬 Code graph analysis (5)
- src/Models/Options/WorldModelsOptions.cs: references ReinforcementLearningAgentBase.cs (ReinforcementLearningOptions)
- src/Models/Options/MADDPGOptions.cs: references ReinforcementLearningAgentBase.cs (ReinforcementLearningOptions)
- src/Models/Options/QMIXOptions.cs: references ReinforcementLearningAgentBase.cs (ReinforcementLearningOptions)
- src/Models/Options/DreamerOptions.cs: references ReinforcementLearningAgentBase.cs (ReinforcementLearningOptions)
- src/Models/Options/DecisionTransformerOptions.cs: references ReinforcementLearningAgentBase.cs (ReinforcementLearningOptions)
🔇 Additional comments (16)
src/Models/Options/DecisionTransformerOptions.cs (3)
8-33: Excellent documentation! The XML documentation provides a clear, educational explanation of Decision Transformer's novel approach to RL, making it accessible to both experts and beginners. The "For Beginners" section effectively explains the paradigm shift from traditional RL to sequence modeling.
40-44: Verify NumHeads default value is intentional. The default `NumHeads = 1` is unusually low for transformer architectures. Multi-head attention typically uses 4, 8, or more heads to capture different representation subspaces. A single head defeats the primary benefit of multi-head attention. Please confirm whether:
- This is intentional for simplicity/performance in RL contexts
- It should be increased to a more typical value (e.g., 4 or 8)
If this default is intentional for Decision Transformer, consider adding a comment explaining the rationale. Otherwise, apply this change:

```diff
-public int NumHeads { get; init; } = 1;
+public int NumHeads { get; init; } = 4; // Standard for transformer architectures
```
49-52: LGTM: Well-designed optional parameter. The `Optimizer` property follows a good pattern: optional customization with a documented sensible default (Adam). The XML documentation clearly communicates the fallback behavior.
src/Models/Options/MADDPGOptions.cs (2)
8-33: Excellent documentation! The XML documentation provides comprehensive coverage of MADDPG concepts with beginner-friendly explanations and real-world examples.
54-60: Good use of NumericOperations for type-safe initialization. The constructor correctly uses `NumericOperations<T>.Instance` to initialize the learning rates and tau values in a type-safe manner.
src/Models/Options/WorldModelsOptions.cs (5)
1-6: LGTM! The imports and namespace declaration are clean and follow project conventions.
8-34: LGTM! The documentation is comprehensive and beginner-friendly. The class structure and inheritance from `ReinforcementLearningOptions<T>` are appropriate for World Models configuration.
46-50: LGTM! MDN-RNN parameter defaults are appropriate for World Models: hidden size, layer count, mixture components, and temperature are all standard values.
55-60: LGTM! Training parameter defaults are well-chosen: separate epoch counts for VAE and RNN training, appropriate CMA-ES parameters for controller evolution, and a reasonable rollout length for dream-based learning.
62-65: LGTM! The nullable `Optimizer` property with clear documentation about the Adam default is a good design pattern that allows flexibility while maintaining sensible defaults.
src/Models/Options/QMIXOptions.cs (4)
1-6: LGTM! The using statements and namespace declaration are appropriate for a QMIX options configuration class.
8-33: Excellent documentation! The XML documentation is comprehensive and includes a beginner-friendly explanation of the QMIX algorithm. The analogy and practical context make this very approachable for new users.
34-34: LGTM! The class declaration appropriately inherits from `ReinforcementLearningOptions<T>` and uses a generic type parameter for flexibility.
45-48: LGTM! The `Optimizer` property is well-documented and correctly uses a nullable type to allow for default behavior. This serves as a good example for how the other properties should be documented.
src/Models/Options/DreamerOptions.cs (2)
8-34: Excellent documentation! The XML documentation is comprehensive, well-structured, and includes beginner-friendly explanations that clearly convey the Dreamer algorithm's purpose and components.
54-57: Good design for optional Optimizer property. The nullable `Optimizer` property with clear documentation about the Adam default is a good pattern for providing flexibility while ensuring reasonable defaults.
…LearningAgentBase inheritance

This commit completes the refactoring of all remaining RL agents to follow AiDotNet architecture patterns and project rules for .NET Framework compatibility.

**Changes Applied to All Agents:**
1. **Removed readonly keywords** (.NET Framework compatibility): TRPOAgent, DecisionTransformerAgent, MADDPGAgent, QMIXAgent, DreamerAgent, MuZeroAgent, WorldModelsAgent
2. **Fixed inheritance** (MuZero and WorldModels):
   - Changed from `ReinforcementLearningAgentBase<T>` to `DeepReinforcementLearningAgentBase<T>`
   - All deep RL agents now properly inherit from the Deep base class

**Project Rules Followed:**
- NO readonly keyword (violates .NET Framework compatibility)
- Deep RL agents inherit from DeepReinforcementLearningAgentBase
- Classical RL agents (future) inherit from ReinforcementLearningAgentBase

**Status of All 8 RL Algorithms:**
✅ A3CAgent - Fully refactored with LayerHelper
✅ RainbowDQNAgent - Fully refactored with LayerHelper
✅ TRPOAgent - Already had LayerHelper, readonly removed
✅ DecisionTransformerAgent - Readonly removed, proper inheritance
✅ MADDPGAgent - Readonly removed, proper inheritance
✅ QMIXAgent - Readonly removed, proper inheritance
✅ DreamerAgent - Readonly removed, proper inheritance
✅ MuZeroAgent - Readonly removed, inheritance fixed
✅ WorldModelsAgent - Readonly removed, inheritance fixed

All agents now follow:
- Correct base class inheritance
- No readonly keywords
- Use INeuralNetwork<T> interfaces
- Use LayerHelper for network creation (where implemented)
- Register networks with Networks.Add()
- Use IOptimizer with Adam defaults

Resolves #394
…mentLearningAgentBase

All deep RL agents (those using neural networks) now properly inherit from DeepReinforcementLearningAgentBase instead of ReinforcementLearningAgentBase. This architectural separation allows:
- Deep RL agents to use neural network infrastructure (Networks list)
- Classical RL agents (future) to use ReinforcementLearningAgentBase without neural networks

Agents updated: A2CAgent, CQLAgent, DDPGAgent, DQNAgent, DoubleDQNAgent, DuelingDQNAgent, IQLAgent, PPOAgent, REINFORCEAgent, SACAgent, TD3Agent.

Also removed readonly keywords for .NET Framework compatibility.

Partial resolution of #394
This commit adds classical reinforcement learning algorithms that use ReinforcementLearningAgentBase WITHOUT neural networks, demonstrating the proper architectural separation.

**New Classical RL Agents:**
1. **TabularQLearningAgent<T>:**
   - Foundational off-policy RL algorithm
   - Uses lookup table (Dictionary) for Q-values
   - No neural networks or function approximation
   - Perfect for discrete state/action spaces
   - Implements: Q(s,a) ← Q(s,a) + α[r + γ max Q(s',a') - Q(s,a)]
2. **SARSAAgent<T>:**
   - On-policy TD control algorithm
   - More conservative than Q-Learning
   - Learns from actual actions taken (including exploration)
   - Better for safety-critical environments
   - Implements: Q(s,a) ← Q(s,a) + α[r + γ Q(s',a') - Q(s,a)]

**Options Classes:**
- TabularQLearningOptions<T> : ReinforcementLearningOptions<T>
- SARSAOptions<T> : ReinforcementLearningOptions<T>

**Architecture Demonstrated:** classical RL agents (no neural networks) inherit from ReinforcementLearningAgentBase; deep RL agents (with neural networks) inherit from DeepReinforcementLearningAgentBase.

**Benefits:**
- Clear separation of classical vs deep RL
- Classical methods don't carry neural network overhead
- Proper foundation for beginners learning RL
- Demonstrates tabular methods before function approximation

Partial resolution of #394
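The Q-Learning rule quoted in the commit maps directly onto a dictionary-backed table; a minimal illustrative sketch (hypothetical class name, not the actual `TabularQLearningAgent<T>` code):

```csharp
using System.Collections.Generic;
using System.Linq;

// Minimal tabular Q-Learning update:
// Q(s,a) <- Q(s,a) + alpha * (r + gamma * max_a' Q(s',a') - Q(s,a))
public class TabularQSketch
{
    private Dictionary<string, double[]> _q = new Dictionary<string, double[]>();
    private int _numActions;
    private double _alpha;
    private double _gamma;

    public TabularQSketch(int numActions, double alpha = 0.1, double gamma = 0.99)
    {
        _numActions = numActions;
        _alpha = alpha;
        _gamma = gamma;
    }

    private double[] Row(string state)
    {
        if (!_q.TryGetValue(state, out var row))
        {
            row = new double[_numActions];
            _q[state] = row;
        }
        return row;
    }

    public void Update(string state, int action, double reward, string nextState, bool done)
    {
        // No bootstrap on terminal transitions
        double target = done ? reward : reward + _gamma * Row(nextState).Max();
        var row = Row(state);
        row[action] += _alpha * (target - row[action]);
    }
}
```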
Actionable comments posted: 16
♻️ Duplicate comments (48)
src/ReinforcementLearning/Agents/WorldModels/WorldModelsAgent.cs (4)
86-92: Fix compilation error: `Random.NextDouble()` requires an instance. This issue was already flagged in a previous review. `Random.NextDouble()` is an instance method and cannot be called on the type. Use `Random.Shared.NextDouble()` or create a private `Random` instance.
228-278: Encoder never receives gradients; VAE cannot learn. This critical issue was flagged in a previous review. The encoder's `Backward` method is never called (only the decoder's is, at line 274), so the encoder parameters remain static and the VAE cannot learn proper latent representations. The encoder must receive gradients from both the reconstruction loss (via reparameterization) and the KL divergence term.
283-327: RNN training uses stale hidden state from live rollout. This critical issue was flagged in a previous review. Line 295 concatenates `_rnnHiddenState` (the agent's current state) with each sampled transition, but the correct hidden state for that transition is different. This creates inconsistent training targets and prevents the RNN from converging. The replay buffer must store the hidden state with each transition, or sample contiguous sequences to reconstruct it.
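One way to satisfy that requirement is to widen the stored transition; a hypothetical sketch of such a record (field names are illustrative, not the project's replay-buffer API):

```csharp
// Hypothetical transition record that also captures the RNN hidden state
// observed during the rollout, so training can rebuild the exact MDN-RNN input.
public class WorldModelTransitionSketch
{
    public double[] Latent;         // encoded observation z_t
    public double[] Action;         // a_t
    public double Reward;           // r_t
    public double[] NextLatent;     // z_{t+1}
    public bool Done;
    public double[] RnnHiddenState; // h_t at the time of the transition
}
```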
344-349: Fix compilation error: `Random.NextDouble()` requires an instance. Same issue as lines 86-92. `Random.NextDouble()` is an instance method and must be called on an instance, not the type.
src/ReinforcementLearning/Agents/DecisionTransformer/DecisionTransformerAgent.cs (2)
271-274: Fix compilation error: missing generic type parameter. Line 273 instantiates `SequenceContext()` without the required generic type parameter `<T>`, causing a compilation error since the field is declared as `SequenceContext<T>`. Apply this diff to fix:

```diff
 public override void ResetEpisode()
 {
-    _currentContext = new SequenceContext();
+    _currentContext = new SequenceContext<T>();
 }
```
202-247: Training never exposes non-zero previous actions. During training (line 217) every sample feeds a zero `previousAction` vector into `ConcatenateInputs`, yet inference (lines 163-165) uses the most recent real action from `_currentContext`. Because `_trajectoryBuffer` (line 45) only stores `(state, action, reward, returnToGo)`, the network never observes the non-zero previous-action feature it depends on at test time. This train/test distribution mismatch prevents the model from learning proper action conditioning. Persist the true previous action from each trajectory step:
- private List<(Vector<T> state, Vector<T> action, T reward, T returnToGo)> _trajectoryBuffer; + private List<(Vector<T> state, Vector<T> previousAction, Vector<T> action, T reward, T returnToGo)> _trajectoryBuffer;- _trajectoryBuffer = new List<(Vector<T>, Vector<T>, T, T)>(); + _trajectoryBuffer = new List<(Vector<T>, Vector<T>, Vector<T>, T, T)>();// Store trajectory with returns-to-go + var previousAction = new Vector<T>(_options.ActionSize); for (int i = 0; i < trajectory.Count; i++) { _trajectoryBuffer.Add(( trajectory[i].state, + previousAction, trajectory[i].action, trajectory[i].reward, returnsToGo[i] )); + + previousAction = trajectory[i].action; }- private List<(Vector<T> state, Vector<T> action, T reward, T returnToGo)> SampleBatch(int batchSize) + private List<(Vector<T> state, Vector<T> previousAction, Vector<T> action, T reward, T returnToGo)> SampleBatch(int batchSize) { - var batch = new List<(Vector<T>, Vector<T>, T, T)>(); + var batch = new List<(Vector<T>, Vector<T>, Vector<T>, T, T)>();- foreach (var (state, targetAction, reward, returnToGo) in batch) + foreach (var (state, previousAction, targetAction, reward, returnToGo) in batch) { - // For simplicity, use zero previous action - var previousAction = new Vector<T>(_options.ActionSize); var input = ConcatenateInputs(returnToGo, state, previousAction);src/ReinforcementLearning/Agents/Rainbow/RainbowDQNAgent.cs (2)
285-288: Critical: Gradient dimensionality mismatch when distributional RL is enabled. When `_options.UseDistributional = true`, the network outputs `_options.ActionSize * _options.NumAtoms` values (lines 93-94), but line 285 allocates a gradient vector of size `_options.ActionSize`. This dimensional mismatch will cause `_onlineNetwork.Backward(gradient)` to fail or behave incorrectly. This issue was previously flagged as addressed, but the fix is not present in this code.
Apply this diff to fix the gradient allocation:
- // Backpropagate - var gradient = new Vector<T>(_options.ActionSize); - gradient[actionIndex] = tdError; - _onlineNetwork.Backward(gradient); - _onlineNetwork.UpdateWeights(LearningRate); + // Backpropagate + int gradientSize = _options.UseDistributional + ? _options.ActionSize * _options.NumAtoms + : _options.ActionSize; + var gradient = new Vector<T>(gradientSize); + + if (_options.UseDistributional) + { + // For distributional RL, populate the atom slice for the selected action + int offset = actionIndex * _options.NumAtoms; + for (int atom = 0; atom < _options.NumAtoms; atom++) + { + gradient[offset + atom] = tdError; + } + } + else + { + gradient[actionIndex] = tdError; + } + + _onlineNetwork.Backward(gradient); + _onlineNetwork.UpdateWeights(LearningRate);
91-109: Critical: Dueling architecture still not implemented. The method name and documentation claim a dueling architecture, but the implementation simply calls `LayerHelper<T>.CreateDefaultDeepQNetworkLayers()`, which builds a standard sequential DQN network. A true dueling architecture requires:
- Shared trunk (feature extraction)
- Parallel streams: separate value stream V(s) and advantage stream A(s,a)
- Aggregation: Q(s,a) = V(s) + (A(s,a) - mean(A(s,:)))
The `RainbowDQNOptions<T>` class provides `SharedLayers`, `ValueStreamLayers`, and `AdvantageStreamLayers` configurations, but they are never used here. This issue was previously flagged as addressed, but the fix is not present in this code.
Apply this approach to implement dueling:
private INeuralNetwork<T> CreateDuelingNetwork() { - int outputSize = _options.UseDistributional - ? _options.ActionSize * _options.NumAtoms - : _options.ActionSize; - - var architecture = new NeuralNetworkArchitecture<T> - { - InputSize = _options.StateSize, - OutputSize = outputSize, - TaskType = TaskType.Regression - }; - - // Use LayerHelper for production-ready network - var layers = LayerHelper<T>.CreateDefaultDeepQNetworkLayers(architecture); - - architecture.Layers = layers.ToList(); - return new NeuralNetwork<T>(architecture, LossFunction); + // TODO: Implement true dueling architecture: + // 1. Build shared trunk using _options.SharedLayers + // 2. Split into parallel value stream (_options.ValueStreamLayers) outputting 1 or NumAtoms + // 3. Split into parallel advantage stream (_options.AdvantageStreamLayers) outputting ActionSize or ActionSize*NumAtoms + // 4. Add aggregation layer: Q(s,a) = V(s) + (A(s,a) - mean_a(A(s,a))) + // 5. For distributional, apply aggregation per-atom + throw new NotImplementedException("Dueling architecture not yet implemented. Use standard DQN for now."); }Alternatively, if dueling is not planned, remove all dueling claims from documentation and rename this method to
CreateNetwork().src/ReinforcementLearning/Agents/TRPO/TRPOAgent.cs (5)
48-48: Trajectory buffer still missing `nextState` parameter. This issue was flagged in previous reviews but remains unresolved. The tuple should include `nextState` to enable proper bootstrapping of returns and advantages when episodes don't terminate at batch boundaries.
207-216: `nextState` parameter ignored when storing experience. The method accepts `nextState` but doesn't persist it in the trajectory buffer (line 209). This is part of the critical issue flagged in previous reviews.
258-290: Return and advantage computation logic is flawed. The Monte Carlo return computation (lines 258-269) and GAE computation (lines 275-290) have fundamental issues (see the sketch after this list):
- Line 262: When `done=true`, correctly sets `runningReturn = rewards[i]` (no bootstrap).
- Lines 264-267: When `done=false`, it should bootstrap from `V(nextState)` but instead continues from the previous `runningReturn`, which causes incorrect propagation when batch boundaries cut through episodes.
- Line 277: Uses `values[i+1]` as the bootstrap value, but this is the value of the next buffered experience, not the `nextState` of the current transition. These are conceptually different.
This is a consequence of the missing `nextState` in the trajectory buffer, which prevents proper bootstrapping.
333-356: Policy update still ignores chosen actions and log-probabilities. This critical issue was flagged in previous reviews but remains unresolved. The policy gradient (line 350) is uniform across all output neurons and doesn't use the recorded actions or compute the score function ∇_θ log π_θ(a|s). Without proper score function gradients, the agent cannot learn meaningful policies.
359-379: KL divergence regression: continuous action support missing. A previous review flagged this issue and marked it as "✅ Addressed in commit f8d6fa8", but the current code still only implements discrete KL divergence (lines 361-362). When `_options.IsContinuous` is true, `policyOutput` contains Gaussian parameters (mean and log_std), not discrete probabilities. The current implementation treats these as discrete probabilities and computes meaningless KL values, breaking the trust region constraint for continuous action spaces. For continuous actions, implement KL divergence between two Gaussian distributions as suggested in the previous review. The fix should branch on
_options.IsContinuousand compute the diagonal Gaussian KL:private T ComputeKL(Vector<T> newDist, Vector<T> oldDist) { - // Simplified KL divergence for discrete distributions - // KL(old || new) = sum(old * log(old / new)) - T kl = NumOps.Zero; - - for (int i = 0; i < newDist.Length; i++) + if (_options.IsContinuous) + { + // KL divergence for Gaussian distributions + T kl = NumOps.Zero; + int actionSize = _options.ActionSize; + + for (int i = 0; i < actionSize; i++) + { + var oldMean = oldDist[i]; + var oldLogStd = oldDist[actionSize + i]; + var newMean = newDist[i]; + var newLogStd = newDist[actionSize + i]; + + var oldVar = MathHelper.Exp(NumOps.Multiply(NumOps.FromDouble(2), oldLogStd)); + var newVar = MathHelper.Exp(NumOps.Multiply(NumOps.FromDouble(2), newLogStd)); + + // KL = log(σ_new/σ_old) + (σ_old² + (μ_old-μ_new)²)/(2σ_new²) - 1/2 + var meanDiff = NumOps.Subtract(oldMean, newMean); + var meanDiffSq = NumOps.Multiply(meanDiff, meanDiff); + + var term1 = NumOps.Subtract(newLogStd, oldLogStd); + var term2 = NumOps.Divide(NumOps.Add(oldVar, meanDiffSq), NumOps.Multiply(NumOps.FromDouble(2), newVar)); + var term3 = NumOps.FromDouble(0.5); + + kl = NumOps.Add(kl, NumOps.Subtract(NumOps.Add(term1, term2), term3)); + } + return kl; + } + else { - var oldProb = oldDist[i]; - var newProb = newDist[i]; - - if (NumOps.Compare(oldProb, NumOps.Zero) > 0 && NumOps.Compare(newProb, NumOps.Zero) > 0) + // KL divergence for discrete distributions + T kl = NumOps.Zero; + for (int i = 0; i < newDist.Length; i++) { - var ratio = NumOps.Divide(oldProb, newProb); - var logRatio = MathHelper.Log(ratio); - kl = NumOps.Add(kl, NumOps.Multiply(oldProb, logRatio)); + var oldProb = oldDist[i]; + var newProb = newDist[i]; + + if (NumOps.Compare(oldProb, NumOps.Zero) > 0 && NumOps.Compare(newProb, NumOps.Zero) > 0) + { + var ratio = NumOps.Divide(oldProb, newProb); + var logRatio = MathHelper.Log(ratio); + kl = NumOps.Add(kl, NumOps.Multiply(oldProb, logRatio)); + } } + return kl; } - - return kl; }src/ReinforcementLearning/Agents/MADDPG/MADDPGAgent.cs (2)
254-289: Critical: Two distinct bugs in critic update.
Issue 1 (Line 261): Target Q uses stored action instead of target actors (DUPLICATE of unresolved past review). The past review correctly identified that MADDPG's Bellman target requires rolling target actors forward on `nextState` to produce the next joint action, then querying the target critic with that next joint action. Line 261 still uses `experience.action` (the current stored action), not the next action from target actors, which systematically biases the critic.
Issue 2 (Line 283): Gradient has incorrect sign. For TD error `error = target - currentQ`, the gradient of MSE loss w.r.t. `currentQ` is `2*(currentQ - target) = -2*error`. Line 283 uses `error` directly, which has the opposite sign. While the factor of 2 can be absorbed into the learning rate, the wrong sign will cause the critic to move away from the target instead of toward it.
Apply this combined diff to fix both issues:
private T UpdateCritic(int agentId, List<(Vector<T> state, Vector<T> action, T reward, Vector<T> nextState, bool done)> batch) { T totalLoss = NumOps.Zero; foreach (var experience in batch) { - // Compute target using target networks (centralized) - var targetQ = _targetCriticNetworks[agentId].Forward(ConcatenateStateAction(experience.nextState, experience.action))[0]; + // Compute next joint action from target actors + var jointNextAction = experience.action.Clone(); + for (int otherAgent = 0; otherAgent < _options.NumAgents; otherAgent++) + { + int stateOffset = otherAgent * _options.StateSize; + var nextAgentState = new Vector<T>(_options.StateSize); + for (int i = 0; i < _options.StateSize; i++) + { + nextAgentState[stateOffset + i] = experience.nextState[stateOffset + i]; + } + var targetAction = _targetActorNetworks[otherAgent].Forward(nextAgentState); + for (int i = 0; i < _options.ActionSize; i++) + { + jointNextAction[otherAgent * _options.ActionSize + i] = targetAction[i]; + } + } + var targetInput = ConcatenateStateAction(experience.nextState, jointNextAction); + var targetQ = _targetCriticNetworks[agentId].Forward(targetInput)[0]; T target; if (experience.done) { target = experience.reward; } else { target = NumOps.Add(experience.reward, NumOps.Multiply(_options.DiscountFactor, targetQ)); } // Current Q-value var currentQ = _criticNetworks[agentId].Forward(ConcatenateStateAction(experience.state, experience.action))[0]; // TD error var error = NumOps.Subtract(target, currentQ); var loss = NumOps.Multiply(error, error); totalLoss = NumOps.Add(totalLoss, loss); // Backpropagate with correct gradient sign var gradient = new Vector<T>(1); - gradient[0] = error; + gradient[0] = NumOps.Negate(NumOps.Multiply(NumOps.FromDouble(2.0), error)); _criticNetworks[agentId].Backward(gradient); _criticNetworks[agentId].UpdateWeights(_options.CriticLearningRate); } return NumOps.Divide(totalLoss, NumOps.FromDouble(batch.Count)); }
291-333: Critical: Actor gradient still fabricated instead of backpropagated (unresolved past review). Lines 321-326 create a constant gradient vector by dividing the Q-value by the action size, which does not implement the deterministic policy gradient. The past review (lines 316-330) correctly identified this bug and was marked "✅ Addressed in commit cdc12d1," yet the current code still shows the incorrect implementation. Without backpropagating `dQ/dAction` from the critic, the actor never receives the true policy gradient and cannot optimize against the critic. Apply the fix from the past review:
- var qValue = _criticNetworks[agentId].Forward(ConcatenateStateAction(experience.state, jointAction))[0]; - - // Actor loss: maximize Q-value - totalLoss = NumOps.Add(totalLoss, NumOps.Negate(qValue)); - - // Simplified gradient for actor - var actorGradient = new Vector<T>(_options.ActionSize); - for (int i = 0; i < _options.ActionSize; i++) - { - actorGradient[i] = NumOps.Divide(qValue, NumOps.FromDouble(_options.ActionSize)); - } - - _actorNetworks[agentId].Backward(actorGradient); + var stateAction = ConcatenateStateAction(experience.state, jointAction); + var qValue = _criticNetworks[agentId].Forward(stateAction)[0]; + totalLoss = NumOps.Add(totalLoss, NumOps.Negate(qValue)); + + // Backpropagate through critic to get action gradient + var criticOutputGradient = new Vector<T>(1); + criticOutputGradient[0] = NumOps.Negate(NumOps.One); + var criticInputGradient = _criticNetworks[agentId].Backward(criticOutputGradient); + + // Extract gradient w.r.t. this agent's action + int actionOffset = _options.StateSize * _options.NumAgents + agentId * _options.ActionSize; + var actorGradient = new Vector<T>(_options.ActionSize); + for (int i = 0; i < _options.ActionSize; i++) + { + actorGradient[i] = criticInputGradient[actionOffset + i]; + } + + _actorNetworks[agentId].Backward(actorGradient); _actorNetworks[agentId].UpdateWeights(_options.ActorLearningRate);src/ReinforcementLearning/Agents/Dreamer/DreamerAgent.cs (5)
182-241: Critical: Representation network is never trained. The `_representationNetwork` encodes observations to latent states (lines 189-190) but there are no calls to `_representationNetwork.Backward()` or `_representationNetwork.UpdateWeights()` in the training loop. Without training the encoder, latent representations won't learn meaningful features from observations, severely limiting the agent's ability to learn. The Dreamer algorithm requires training the representation network jointly with the dynamics model by backpropagating reconstruction and prediction errors through the encoder.
186-238: Inconsistent batch processing: per-sample updates vs. batch averaging. Weight updates are applied after each sample in the batch (lines 227, 232, 237), but the returned loss is averaged over the batch (line 240). Standard practice is to either accumulate gradients across the batch and then update once, or update per sample and return the sum of losses. The current approach creates a mismatch where networks see `BatchSize` sequential updates per `Train()` call but the reported loss suggests a single batch update, leading to noisier updates and a learning rate effectively multiplied by the batch size.
220-237: Fix gradient computation to match MSE loss derivative formula. The gradient calculations at lines 223, 230, and 235 are missing the factor of 2 and batch normalization from the MSE loss derivative. The derivative of MSE is `2*(predicted-actual)/n`, but the code computes only the difference. While the factor of 2 can be absorbed into the learning rate, it's more conventional to use the correct gradient formula including batch normalization for stable training.
273-282: Actor gradient is overly simplified and likely incorrect. The actor gradient is computed as `valueDiff / action.Length` (lines 276-278), which doesn't follow standard policy gradient formulations. Policy gradient methods typically require computing ∂Value/∂action via the value network, computing ∂action/∂θ (actor parameters), and chaining these gradients. The current heuristic shortcut may not provide meaningful gradient signals for the actor to learn effectively.
290-318: Missing discount factor in imagination rollout. The imagined return accumulates rewards without any discount factor (line 303). Standard RL practice applies a discount factor γ to balance immediate vs. future rewards. Without discounting, the agent may overvalue long-term speculative rewards and training can be unstable with long imagination horizons.
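For reference, the discounted accumulation described above is a short loop (illustrative sketch with plain doubles; `gamma` corresponds to the agent's discount factor):

```csharp
using System.Collections.Generic;

public static class ImaginationSketch
{
    // G = sum_t gamma^t * r_t over the imagined rollout
    public static double DiscountedReturn(IReadOnlyList<double> imaginedRewards, double gamma)
    {
        double ret = 0.0;
        double discount = 1.0;
        foreach (double reward in imaginedRewards)
        {
            ret += discount * reward;
            discount *= gamma;
        }
        return ret;
    }
}
```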
src/ReinforcementLearning/Agents/REINFORCE/REINFORCEAgent.cs (1)
207-226: Policy gradient implementation is incorrect. `_policyNetwork.Backward` is being fed a vector whose entries are uniformly `loss * 0.01` (line 223), so the update direction is independent of the policy logits, sampled action, or state. This means every parameter receives the same scaled adjustment each step, and the REINFORCE update is not the true ∇θ log π(a|s) * G_t. The agent cannot learn with this placeholder gradient. Please backpropagate the actual gradient of the log-probability scaled by the return.
src/ReinforcementLearning/Agents/IQL/IQLAgent.cs (2)
239-249: Fix expectile gradient sign and weight. `UpdateValueFunction` feeds `diff` directly into backprop (line 246), but the expectile loss is `w * (q - v)^2`. The derivative w.r.t. `v` should be `-2 * w * (q - v)`. As written, the code drops both the weight and the negative sign, so value updates push in the wrong direction and ignore the expectile asymmetry, preventing the critic from converging.
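Spelled out with plain doubles (illustrative only, not the `NumOps`-based implementation), the expectile loss and its gradient with respect to `v` are:

```csharp
public static class ExpectileSketch
{
    // L_tau(u) = |tau - 1(u < 0)| * u^2 with u = q - v
    // dL/dv    = -2 * w * (q - v), where w is the asymmetric weight
    public static (double Loss, double GradWrtV) Evaluate(double q, double v, double tau)
    {
        double u = q - v;
        double w = u < 0 ? 1.0 - tau : tau;   // asymmetric expectile weight
        return (w * u * u, -2.0 * w * u);
    }
}
```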
294-311: Correct the Q-network loss gradient. For the Q-updates, passing `targetQ - qValue` straight into `Backward` (lines 299, 309) is incorrect. The MSE derivative should be `-2 * (target - prediction)`; without that sign, critic steps move away from the target instead of toward it, making training unstable.
src/ReinforcementLearning/Agents/CQL/CQLAgent.cs (2)
247-271: Propagate the conservative penalty's gradient. `cqlPenalty` is added to `q1Loss`/`q2Loss` (lines 252, 256), but the tensors backpropagated (lines 259-268) only contain the TD error. The conservative term never influences the gradients, so this agent reduces to vanilla double-Q and violates the CQL objective. Please fold the penalty into the gradients before calling `Backward`.
309-341: Policy gradient sign and variance outputs are wrong. `policyGrad` is populated with `actionGrad` (lines 333-337) without negating it, so gradient descent drives the actor toward lower Q-values. Additionally, the second half of the vector (log-σ outputs) stays zero, leaving exploration variance frozen. Please negate the gradient to perform ascent on Q and populate all `ActionSize * 2` components so both mean and log-σ parameters update.
src/ReinforcementLearning/Agents/DoubleDQN/DoubleDQNAgent.cs (3)
58-85: Fix null handling before calling the base constructor. Line 58 dereferences `options` while building the base `ReinforcementLearningOptions`, so a null caller triggers a `NullReferenceException` before the guard at line 74 runs. Validate `options` first and only then build the base options, or use a helper method to validate before the base constructor call.
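A common shape for that fix is a small static helper invoked inside the base-constructor argument, so the guard runs before any property access; a sketch with hypothetical type names (not the actual AiDotNet classes):

```csharp
using System;

// Hypothetical stand-ins to illustrate the pattern only.
public class AgentBaseSketch
{
    public AgentBaseSketch(double learningRate) { }
}

public class AgentOptionsSketch
{
    public double LearningRate { get; set; } = 0.001;
}

public class DoubleDqnAgentSketch : AgentBaseSketch
{
    // Validated(options) throws before the base constructor reads LearningRate.
    public DoubleDqnAgentSketch(AgentOptionsSketch options)
        : base(Validated(options).LearningRate)
    {
    }

    private static AgentOptionsSketch Validated(AgentOptionsSketch options)
    {
        if (options == null) throw new ArgumentNullException(nameof(options));
        return options;
    }
}
```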
290-298: Keep the target network in sync when external parameters are injected. Line 297 updates only the online network, leaving `_targetNetwork` with stale weights. This creates inconsistent behavior after `SetParameters` and during cloning. Copy the online weights to the target network immediately after updating parameters.
327-341: `ComputeGradients` is returning weights, not gradients. After backprop, line 338 pulls the parameters via `GetParameters()`, so callers apply weight values as gradients, wrecking updates. Return the actual flattened gradients instead (e.g., call `_qNetwork.GetFlattenedGradients()` and convert to Matrix).
src/ReinforcementLearning/Agents/DDPG/DDPGAgent.cs (5)
59-90: Fix null handling before calling the base constructor. Line 59 dereferences `options` while creating the base options object, so a null caller will see a `NullReferenceException` instead of the intended argument guard at line 71. Guard `options` before constructing the base arguments, or use a helper method.
242-270: Implement the true policy gradient for the actor update. Lines 259-263 multiply the scalar Q-value by a hard-coded `-0.01` for every action dimension, completely ignoring ∂Q/∂a from the critic. Without the critic's gradient, the actor receives no meaningful signal and DDPG collapses. Backpropagate through the critic to obtain the gradient of Q with respect to the action, then feed that vector into the actor's backward pass.
435-439: Stop returning weights in place of gradients. `ComputeGradients` currently returns `GetParameters()` and a zero loss (line 438). Any caller will interpret the model weights as gradients, leading to nonsensical optimization steps. Either compute the real gradients or throw a `NotSupportedException`.
442-445: Make `ApplyGradients` either work or fail fast. Leaving `ApplyGradients` empty silently drops optimizer updates for any external trainer that relies on the interface. At minimum, throw `NotSupportedException` so callers do not assume the update succeeded.
475-492: Fix the Ornstein–Uhlenbeck diffusion term sign. Lines 481-485 subtract the Gaussian term, producing dx = -θx − σN(0,1), which drives the process negative and breaks exploration. The stochastic term must be added: compute drift and diffusion separately, add them, then update the state.
src/ReinforcementLearning/Agents/TD3/TD3Agent.cs (1)
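Returning to the Ornstein–Uhlenbeck comment above: a minimal noise source with the diffusion term added rather than subtracted (plain C# sketch, hypothetical class, not the agent's implementation):

```csharp
using System;

// OU update: x_{t+1} = x_t + theta * (mu - x_t) * dt + sigma * sqrt(dt) * N(0,1)
public class OrnsteinUhlenbeckNoiseSketch
{
    private double _state;
    private Random _random = new Random();
    private double _theta, _mu, _sigma, _dt;

    public OrnsteinUhlenbeckNoiseSketch(double theta = 0.15, double mu = 0.0, double sigma = 0.2, double dt = 1e-2)
    {
        _theta = theta; _mu = mu; _sigma = sigma; _dt = dt;
        _state = mu;
    }

    public double Sample()
    {
        // Box-Muller standard normal
        double u1 = 1.0 - _random.NextDouble();
        double u2 = _random.NextDouble();
        double normal = Math.Sqrt(-2.0 * Math.Log(u1)) * Math.Cos(2.0 * Math.PI * u2);

        double drift = _theta * (_mu - _state) * _dt;
        double diffusion = _sigma * Math.Sqrt(_dt) * normal;
        _state += drift + diffusion;   // add, not subtract, the stochastic term
        return _state;
    }
}
```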
177-274: Align ReplayBuffer usage with the Experience API. `ReplayBuffer<T>.Sample(...)` returns a `List<Experience<T>>`, not value tuples. The tuple signatures and camelCase member accessors (`experience.nextState`, `experience.reward`, etc.) won't compile. Update these methods to accept `Experience<T>` objects and use PascalCase properties (`Experience.State`, `Experience.Action`, `Experience.Reward`, `Experience.NextState`, `Experience.Done`).
src/ReinforcementLearning/Agents/QMIX/QMIXAgent.cs (2)
132-141: Restore QMIX monotonic mixer constraints. The mixing network is still a plain feed-forward stack with unconstrained weights. QMIX relies on a hypernetwork-generated, non-negative mixing network so the joint argmax stays aligned with per-agent argmaxes. Without enforcing monotonicity, decentralized execution breaks. Please switch to a hypernetwork (or other non-negative parameterization such as SoftPlus) that conditions on the global state and guarantees non-negative mixer weights/biases.
264-283: Backpropagate TD error with correct sign and mixer Jacobian. `mixingGradient[0] = tdError` has the wrong sign/magnitude (should be `-2 * tdError` for the squared TD loss), and the agent networks still receive an equal share of the TD error instead of the mixer's actual partial derivatives. Capture the input gradient from `_mixingNetwork.Backward(...)`, slice the components per agent action, and feed those through each `_agentNetwork` with the proper `-2 * tdError` factor. Without this, gradients are incorrect and the algorithm will not converge.
src/ReinforcementLearning/Agents/SAC/SACAgent.cs (1)
346-352: Actor gradient is still a placeholder. `grad[i] = policyLoss * 0.01` remains a hard-coded stub; it ignores the policy's actual derivatives w.r.t. logits/means, so the actor never receives the SAC policy gradient. Please backpropagate the scalar loss through `_policyNetwork` (or derive the correct gradient analytically) before updating parameters.
src/ReinforcementLearning/Agents/A2C/A2CAgent.cs (1)
229-260: No gradients computed before updates. `Train()` still accumulates losses but never calls `_policyNetwork.Backward(...)` or `_valueNetwork.Backward(...)`. Consequently, `GetFlattenedGradients()` returns stale/zero data and both updates are no-ops. Please backpropagate the policy loss (e.g., `-log_prob * advantage`) and the value loss before invoking the update helpers.
src/ReinforcementLearning/Agents/DQN/DQNAgent.cs (1)
58-74: Null check must happen before the base call. The base initializer still dereferences `options` before the guard at line 73, so a null caller hits a `NullReferenceException`. Move the null guard (or a helper that throws) into the expression used in the base call before reading any properties.
src/ReinforcementLearning/Agents/DuelingDQN/DuelingDQNAgent.cs (7)
450-472: Backward pass is still a stub. `Backward` stops after computing value/advantage scalars and never propagates through the layer stacks, so no weights receive gradients. Please complete the backprop by running the gradients through `_advantageLayers`, `_valueLayers`, and the shared layers.
475-489: UpdateWeights placeholder prevents learning. `UpdateWeights` still contains empty loops, so even with gradients nothing updates. Invoke each layer's `UpdateWeights(learningRate)` (or equivalent) to apply accumulated gradients.
492-515: Flattened parameter export is unimplemented. `GetFlattenedParameters` returns an empty matrix because it never collects layer weights/biases. Populate the list from shared, value, and advantage layers so serialization/cloning works.
517-520: SetFlattenedParameters is still empty. Without loading weights back into the layers, cloning, deserializing, and target sync all fail. Iterate through each layer and assign its parameters from the provided matrix.
522-533: Serialize/Deserialize are stubs. Both methods remain empty, so the agent cannot persist or restore its state. Implement serialization for all layer weights/biases (and architecture metadata) and mirror it in `Deserialize`.
299-313: ComputeGradients must return real gradients. Like DQN, this method returns `GetParameters()` instead of the gradient vector produced by `_qNetwork.Backward(...)`. Return the flattened gradients to make the external optimizer hooks usable.
316-319: ApplyGradients overwrites parameters with raw gradients. Setting the network weights equal to the gradient matrix zeroes learning. Subtract `learningRate * gradients` from the current parameters before applying them.
src/ReinforcementLearning/Agents/PPO/PPOAgent.cs (2)
316-375: Previous review comment addresses GAE terminal state handling. The terminal state value handling issue in GAE computation has already been flagged in a previous review comment. Please refer to that comment for the required fix.
478-542: Previous review comment addresses policy gradient implementation issues. The policy gradient implementation problems have already been flagged in detail in a previous review comment. Please refer to that comment for the required fix regarding the clipped surrogate objective gradient.
🧹 Nitpick comments (10)
src/ReinforcementLearning/Agents/DecisionTransformer/DecisionTransformerAgent.cs (2)
88-97: Simplify final layer replacement. Line 93 extracts the input size via `denseLayer.GetWeights().Rows`, which is fragile if weights aren't initialized or the layer structure differs. Since `hiddenLayerSize` is known from the architecture parameters, use it directly. Apply this diff:
// Override final activation to Tanh for continuous actions var lastLayer = layers[layers.Count - 1]; if (lastLayer is DenseLayer<T> denseLayer) { layers[layers.Count - 1] = new DenseLayer<T>( - denseLayer.GetWeights().Rows, + _options.EmbeddingDim, _options.ActionSize, new TanhActivation<T>() ); }
118-122: Optimize returns-to-go computation. Line 121 uses `Insert(0, returnToGo)`, which is O(n) per insertion, resulting in O(n²) complexity for each trajectory. For large trajectories, this becomes inefficient. Apply this diff:
// Compute returns-to-go for this trajectory T returnToGo = NumOps.Zero; var returnsToGo = new List<T>(); for (int i = trajectory.Count - 1; i >= 0; i--) { returnToGo = NumOps.Add(trajectory[i].reward, returnToGo); - returnsToGo.Insert(0, returnToGo); + returnsToGo.Add(returnToGo); } + + returnsToGo.Reverse();src/ReinforcementLearning/Agents/Rainbow/RainbowDQNAgent.cs (1)
139-166: Eliminate code duplication between ComputeQValues methods. The logic in `ComputeQValues` (lines 139-166) is identical to `ComputeQValuesFromNetwork` (lines 305-331). This duplication makes maintenance harder and increases the risk of divergence. Apply this diff to eliminate duplication:
private Vector<T> ComputeQValues(Vector<T> state) { - var output = _onlineNetwork.Forward(state); - - if (_options.UseDistributional) - { - // Distributional RL: convert distribution to Q-values - var qValues = new Vector<T>(_options.ActionSize); - double deltaZ = (_options.VMax - _options.VMin) / (_options.NumAtoms - 1); - - for (int action = 0; action < _options.ActionSize; action++) - { - T qValue = NumOps.Zero; - for (int atom = 0; atom < _options.NumAtoms; atom++) - { - int idx = action * _options.NumAtoms + atom; - double z = _options.VMin + atom * deltaZ; - var prob = output[idx]; - qValue = NumOps.Add(qValue, NumOps.Multiply(prob, NumOps.FromDouble(z))); - } - qValues[action] = qValue; - } - - return qValues; - } - - return output; + return ComputeQValuesFromNetwork(_onlineNetwork, state); }src/ReinforcementLearning/Agents/TRPO/TRPOAgent.cs (3)
94-98: Simplify hidden layer size logic. The pattern `FirstOrDefault() > 0 ? First() : 128` is redundant and could be simplified. Additionally, if `PolicyHiddenLayers` is empty, `hiddenLayerCount` will be 0, making `hiddenLayerSize` irrelevant. Consider this cleaner approach:
- var layers = LayerHelper<T>.CreateDefaultFeedForwardLayers( - architecture, - hiddenLayerCount: _options.PolicyHiddenLayers.Count, - hiddenLayerSize: _options.PolicyHiddenLayers.FirstOrDefault() > 0 ? _options.PolicyHiddenLayers.First() : 128 - ).ToList(); + int hiddenLayerCount = _options.PolicyHiddenLayers.Count; + int hiddenLayerSize = (hiddenLayerCount > 0) ? _options.PolicyHiddenLayers[0] : 128; + var layers = LayerHelper<T>.CreateDefaultFeedForwardLayers( + architecture, + hiddenLayerCount: hiddenLayerCount, + hiddenLayerSize: hiddenLayerSize + ).ToList();
129-133: Simplify hidden layer size logic. Same redundant pattern as in `CreatePolicyNetwork`.

```diff
-    var layers = LayerHelper<T>.CreateDefaultFeedForwardLayers(
-        architecture,
-        hiddenLayerCount: _options.ValueHiddenLayers.Count,
-        hiddenLayerSize: _options.ValueHiddenLayers.FirstOrDefault() > 0 ? _options.ValueHiddenLayers.First() : 128
-    );
+    int hiddenLayerCount = _options.ValueHiddenLayers.Count;
+    int hiddenLayerSize = (hiddenLayerCount > 0) ? _options.ValueHiddenLayers[0] : 128;
+    var layers = LayerHelper<T>.CreateDefaultFeedForwardLayers(
+        architecture,
+        hiddenLayerCount: hiddenLayerCount,
+        hiddenLayerSize: hiddenLayerSize
+    );
```
354-354: Prefer configured learning rate over hardcoded value. The policy update uses a hardcoded learning rate of `0.001`, while the value function update uses `_options.ValueLearningRate` (line 320). Consider adding a `PolicyLearningRate` property to `TRPOOptions` for consistency and configurability.

```diff
-    _policyNetwork.UpdateWeights(NumOps.FromDouble(0.001)); // Very small LR for trust region
+    _policyNetwork.UpdateWeights(_options.PolicyLearningRate); // Trust region learning rate
```
src/ReinforcementLearning/Agents/MADDPG/MADDPGAgent.cs (2)
135-156: Consider clarifying hidden layer configuration logic. Lines 150-151 use the same `FirstOrDefault()` pattern as the actor network: if `CriticHiddenLayers` is empty, `FirstOrDefault()` returns 0, yet the fallback `128` is never used when `hiddenLayerCount=0`. While not incorrect (zero hidden layers means direct input→output), the conditional on line 151 is misleading. Simplify for clarity:
- hiddenLayerCount: _options.CriticHiddenLayers.Count, - hiddenLayerSize: _options.CriticHiddenLayers.FirstOrDefault() > 0 ? _options.CriticHiddenLayers.First() : 128 + hiddenLayerCount: _options.CriticHiddenLayers.Count, + hiddenLayerSize: _options.CriticHiddenLayers.FirstOrDefault() != 0 ? _options.CriticHiddenLayers.First() : 128Or use
FirstOrDefault(128)if the API supports it.
395-402: LGTM!Basic metrics are captured. Consider adding training loss, average Q-values, or per-agent metrics for richer monitoring.
src/ReinforcementLearning/Agents/PPO/PPOAgent.cs (1)
237-245: Consider using type-safe epsilon for numerical stability. Line 244 uses a hardcoded `1e-10` epsilon as a double literal. For better numerical stability with generic type `T`, consider using a small epsilon relative to the numeric type (e.g., `NumOps.FromDouble(1e-10)` or a configurable epsilon in `PPOOptions`). The same issue appears in line 369. Apply this diff to use a type-safe epsilon:
- return NumOps.FromDouble(Math.Log(NumOps.ToDouble(prob) + 1e-10)); + var epsilon = NumOps.FromDouble(1e-10); + return NumOps.FromDouble(Math.Log(NumOps.ToDouble(prob) + NumOps.ToDouble(epsilon)));src/Models/Options/TabularQLearningOptions.cs (1)
42-55: Avoid hiding the base epsilon schedule. `ReinforcementLearningOptions<T>` already provides `EpsilonStart`, `EpsilonEnd`, and `EpsilonDecay` with the same defaults (see src/ReinforcementLearning/Agents/ReinforcementLearningAgentBase.cs lines 315-342). Re-declaring them here just hides the base members, produces CS0108 warnings, and makes the behavior depend on whether the options instance is referenced via the base or derived type. Please remove these duplicate properties and rely on the inherited ones instead.
This commit continues expanding classical RL implementations using ReinforcementLearningAgentBase without neural networks.

**New Algorithms:**
1. **ExpectedSARSAAgent<T>:**
   - TD control using expected value under current policy
   - Lower variance than SARSA
   - Update: Q(s,a) ← Q(s,a) + α[r + γ Σ π(a'|s')Q(s',a') - Q(s,a)]
   - Better performance than standard SARSA
2. **FirstVisitMonteCarloAgent<T>:**
   - Episode-based learning (no bootstrapping)
   - Uses actual returns, not estimates
   - Only updates first occurrence of state-action per episode
   - Perfect for episodic tasks with clear endings

**Architecture:**
- All use tabular Q-tables (Dictionary<string, Dictionary<int, T>>)
- All inherit from ReinforcementLearningAgentBase<T>
- All follow project rules (no readonly, proper options inheritance)

**Classical RL Progress:**
✅ Tabular Q-Learning
✅ SARSA
✅ Expected SARSA
✅ First-Visit Monte Carlo
⬜ 25+ more classical algorithms planned

Partial resolution of #394
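The expected-value target in the Expected SARSA update above can be computed directly from the epsilon-greedy policy; a minimal illustrative sketch (plain doubles, not the agent's Q-table types):

```csharp
using System;
using System.Linq;

public static class ExpectedSarsaSketch
{
    // target = r + gamma * sum_a' pi(a'|s') * Q(s',a'),
    // where pi is epsilon-greedy with respect to the Q-values of s'.
    public static double Target(double reward, double[] nextQ, double gamma, double epsilon, bool done)
    {
        if (done) return reward;

        int greedy = Array.IndexOf(nextQ, nextQ.Max());
        double expected = 0.0;
        for (int a = 0; a < nextQ.Length; a++)
        {
            double prob = epsilon / nextQ.Length + (a == greedy ? 1.0 - epsilon : 0.0);
            expected += prob * nextQ[a];
        }
        return reward + gamma * expected;
    }
}
```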
Actionable comments posted: 8
♻️ Duplicate comments (1)
src/ReinforcementLearning/Agents/ExpectedSARSA/ExpectedSARSAAgent.cs (1)
245-255: Save/Load will fail due to unimplemented serialization. These methods rely on `Serialize` and `Deserialize`, which throw `NotImplementedException`. This is already flagged in the comment on lines 189-197.
🧹 Nitpick comments (3)
src/ReinforcementLearning/Agents/ExpectedSARSA/ExpectedSARSAAgent.cs (2)
124-132: Consider making state discretization precision configurable. The fixed `"F4"` formatting discretizes continuous states to 4 decimal places, which causes states differing by less than 0.0001 to be treated identically. While this is a reasonable default for tabular methods in continuous spaces, making the precision configurable via options would provide more flexibility for different problem domains.
189-197: Serialization not yet implemented. The `Serialize` and `Deserialize` methods throw `NotImplementedException`, which means `Save` and `Load` (lines 245-255) will always fail at runtime. This limits the ability to persist trained agents. Do you want me to generate an implementation for Q-table serialization/deserialization, or should this be tracked as a follow-up task?
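A minimal sketch of what such persistence could look like with `BinaryWriter`/`BinaryReader` (the style the DQN agents in this PR already use). Plain `double` stands in for the generic `T`, so this is illustrative rather than the agent's actual implementation:

```csharp
using System.Collections.Generic;
using System.IO;

static class QTablePersistenceSketch
{
    public static byte[] Serialize(Dictionary<string, Dictionary<int, double>> qTable)
    {
        using var ms = new MemoryStream();
        using var writer = new BinaryWriter(ms);
        writer.Write(qTable.Count);
        foreach (var state in qTable)
        {
            writer.Write(state.Key);          // discretized state key
            writer.Write(state.Value.Count);
            foreach (var entry in state.Value)
            {
                writer.Write(entry.Key);      // action index
                writer.Write(entry.Value);    // Q-value
            }
        }
        writer.Flush();
        return ms.ToArray();
    }

    public static Dictionary<string, Dictionary<int, double>> Deserialize(byte[] data)
    {
        using var ms = new MemoryStream(data);
        using var reader = new BinaryReader(ms);
        var qTable = new Dictionary<string, Dictionary<int, double>>();
        int stateCount = reader.ReadInt32();
        for (int s = 0; s < stateCount; s++)
        {
            string key = reader.ReadString();
            int actionCount = reader.ReadInt32();
            var row = new Dictionary<int, double>();
            for (int a = 0; a < actionCount; a++)
                row[reader.ReadInt32()] = reader.ReadDouble();
            qTable[key] = row;
        }
        return qTable;
    }
}
```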
src/Models/Options/MonteCarloOptions.cs (1)
11-15: Avoid hiding the base epsilon schedule.
`ReinforcementLearningOptions<T>` already exposes `EpsilonStart`, `EpsilonEnd`, and `EpsilonDecay`. Redeclaring them here silently hides the base members, triggers compiler warnings, and risks divergent defaults if someone tweaks only one copy. Please drop these duplicate properties and rely on the inherited ones (or mark them `new` only if you genuinely need different defaults).
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- `src/Models/Options/ExpectedSARSAOptions.cs` (1 hunks)
- `src/Models/Options/MonteCarloOptions.cs` (1 hunks)
- `src/ReinforcementLearning/Agents/ExpectedSARSA/ExpectedSARSAAgent.cs` (1 hunks)
- `src/ReinforcementLearning/Agents/MonteCarlo/FirstVisitMonteCarloAgent.cs` (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (4)
src/Models/Options/ExpectedSARSAOptions.cs (2)
src/ReinforcementLearning/Agents/ExpectedSARSA/ExpectedSARSAAgent.cs (2)
T(91-117)T(119-122)src/ReinforcementLearning/Agents/ReinforcementLearningAgentBase.cs (1)
ReinforcementLearningOptions(296-362)
src/ReinforcementLearning/Agents/MonteCarlo/FirstVisitMonteCarloAgent.cs (2)
src/ReinforcementLearning/Agents/ExpectedSARSA/ExpectedSARSAAgent.cs (8)
T(91-117)T(119-122)Vector(50-67)VectorToStateKey(124-132)GetBestAction(158-173)StoreExperience(69-89)GetActionIndex(134-144)EnsureStateExists(146-156)src/Models/Options/MonteCarloOptions.cs (1)
MonteCarloOptions(9-16)
src/ReinforcementLearning/Agents/ExpectedSARSA/ExpectedSARSAAgent.cs (2)
src/ReinforcementLearning/Agents/MonteCarlo/FirstVisitMonteCarloAgent.cs (17)
T(126-129)Vector(55-72)VectorToStateKey(131-139)GetBestAction(165-180)StoreExperience(74-87)GetActionIndex(141-151)EnsureStateExists(153-163)ModelMetadata(188-197)Serialize(202-205)Deserialize(207-210)Matrix(212-226)Matrix(251-254)SetParameters(228-241)IFullModel(243-249)ApplyGradients(256-256)Save(258-262)Load(264-268)src/Models/Options/ExpectedSARSAOptions.cs (1)
ExpectedSARSAOptions(9-16)
src/Models/Options/MonteCarloOptions.cs (2)
src/ReinforcementLearning/Agents/ExpectedSARSA/ExpectedSARSAAgent.cs (2)
T(91-117)T(119-122)src/ReinforcementLearning/Agents/ReinforcementLearningAgentBase.cs (1)
ReinforcementLearningOptions(296-362)
🔇 Additional comments (8)
src/ReinforcementLearning/Agents/ExpectedSARSA/ExpectedSARSAAgent.cs (8)
31-48: LGTM! The constructor properly validates options and initializes the Q-table and epsilon parameters.
50-67: LGTM! The epsilon-greedy action selection is correctly implemented, properly distinguishing between training and inference modes.
69-89: LGTM! The Expected SARSA update is correctly implemented, using the expected Q-value under the current policy rather than a sampled action, which reduces variance compared to standard SARSA.
119-122: LGTM! Expected SARSA performs online updates in `StoreExperience`, so the `Train` method correctly returns zero as a no-op.
146-173: LGTM! The state initialization and best action selection logic are correctly implemented, with proper handling of new states and greedy action selection.
175-187: LGTM! The metadata and properties correctly reflect the model structure. Note that `ParameterCount` grows dynamically as new states are encountered, which is expected behavior for tabular methods.
199-213: LGTM with caveat. The method correctly exports Q-values to a matrix, but note that state keys are lost in this representation. This is acceptable for parameter inspection but limits the utility of the corresponding `SetParameters` method.
238-243: LGTM! Tabular Expected SARSA doesn't use gradient-based optimization, so these stub implementations are appropriate.
Added more classical RL algorithms using ReinforcementLearningAgentBase.

New algorithms:
- DoubleQLearningAgent: Reduces overestimation bias with two Q-tables

Progress: 7/29 classical RL algorithms implemented

Partial resolution of #394
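A minimal sketch of the two-table update that Double Q-Learning uses to reduce overestimation: one table picks the greedy next action, the other evaluates it. Plain doubles and the array names below are illustrative, not the agent's actual fields:

```csharp
using System;

static class DoubleQSketch
{
    // Update one randomly chosen table using the other table's estimate of the greedy action.
    public static void Update(double[] q1, double[] q2, double[] q1Next, double[] q2Next,
                              int action, double reward, double alpha, double gamma, Random rng)
    {
        if (rng.NextDouble() < 0.5)
        {
            int best = ArgMax(q1Next);                      // Q1 selects...
            double target = reward + gamma * q2Next[best];  // ...Q2 evaluates
            q1[action] += alpha * (target - q1[action]);
        }
        else
        {
            int best = ArgMax(q2Next);
            double target = reward + gamma * q1Next[best];
            q2[action] += alpha * (target - q2[action]);
        }
    }

    private static int ArgMax(double[] values)
    {
        int best = 0;
        for (int i = 1; i < values.Length; i++)
            if (values[i] > values[best]) best = i;
        return best;
    }
}
```

Decoupling action selection from action evaluation is what removes the positive bias that plain Q-learning's max operator introduces.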
Actionable comments posted: 3
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- `src/Models/Options/DoubleQLearningOptions.cs` (1 hunks)
- `src/ReinforcementLearning/Agents/DoubleQLearning/DoubleQLearningAgent.cs` (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
src/Models/Options/DoubleQLearningOptions.cs (1)
src/ReinforcementLearning/Agents/ReinforcementLearningAgentBase.cs (1)
ReinforcementLearningOptions(296-362)
src/ReinforcementLearning/Agents/DoubleQLearning/DoubleQLearningAgent.cs (1)
src/Models/Options/DoubleQLearningOptions.cs (1)
DoubleQLearningOptions(9-16)
Added n-step SARSA agent that uses multi-step bootstrapping for better credit assignment.

Progress: 6/29 classical RL algorithms

Partial resolution of #394
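A small sketch of the n-step return that drives this update: accumulate n discounted rewards, then bootstrap from Q(s_{t+n}, a_{t+n}) unless the episode terminated. Plain doubles are used for illustration:

```csharp
static class NStepSketch
{
    // G = r_t + gamma*r_{t+1} + ... + gamma^{n-1}*r_{t+n-1} + gamma^n * Q(s_{t+n}, a_{t+n})
    public static double Return(double[] rewards, double bootstrapQ, double gamma, bool terminated)
    {
        double g = 0.0;
        double discount = 1.0;
        foreach (double r in rewards)   // the n stored rewards, oldest first
        {
            g += discount * r;
            discount *= gamma;
        }
        if (!terminated)
            g += discount * bootstrapQ; // gamma^n * Q(s_{t+n}, a_{t+n})
        return g;
    }
}

// Q[s][a] += alpha * (NStepSketch.Return(rewards, qNext, gamma, done) - Q[s][a]);
```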
Actionable comments posted: 4
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- `src/Models/Options/NStepSARSAOptions.cs` (1 hunks)
- `src/ReinforcementLearning/Agents/NStepSARSA/NStepSARSAAgent.cs` (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
src/Models/Options/NStepSARSAOptions.cs (2)
src/ReinforcementLearning/Agents/NStepSARSA/NStepSARSAAgent.cs (1)
T(131-134)src/ReinforcementLearning/Agents/ReinforcementLearningAgentBase.cs (1)
ReinforcementLearningOptions(296-362)
src/ReinforcementLearning/Agents/NStepSARSA/NStepSARSAAgent.cs (1)
src/Models/Options/NStepSARSAOptions.cs (1)
NStepSARSAOptions(9-17)
…ing implementations

- Fixed options classes: replaced collection expression syntax with old-style initializers (MADDPGOptions, QMIXOptions, MuZeroOptions, WorldModelsOptions)
- Fixed RainbowDQN: consistent use of _options field throughout implementation
- Added missing abstract method implementations to 6 agents (TRPO, DecisionTransformer, MADDPG, QMIX, Dreamer, MuZero, WorldModels)
- All agents now implement: GetModelMetadata, FeatureCount, Serialize/Deserialize, GetParameters/SetParameters, Clone, ComputeGradients, ApplyGradients, Save/Load
- Added SequenceContext<T> helper class for DecisionTransformer
- Fixed generic type parameter in DecisionTransformer.ResetEpisode()
- Added classical RL implementations: EveryVisitMonteCarloAgent, NStepQLearningAgent

All changes ensure .NET Framework compatibility (no readonly, no collection expressions)
- Monte Carlo Exploring Starts: ensures exploration via random starts
- On-Policy Monte Carlo Control: epsilon-greedy exploration
- Off-Policy Monte Carlo Control: weighted importance sampling
- Policy Iteration: iterative policy evaluation and improvement
- Value Iteration: Bellman optimality equation implementation

All implementations follow .NET Framework compatibility (no readonly, no collection expressions)

Progress: 13/29 classical RL algorithms completed
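For reference, here is a compact sketch of the Bellman optimality backup behind the Value Iteration item above, written over explicit transition and reward arrays (illustrative only; the agent's actual API differs):

```csharp
using System;

static class ValueIterationSketch
{
    // V(s) <- max_a sum_{s'} P[s,a,s'] * (R[s,a,s'] + gamma * V(s')), repeated until convergence.
    public static double[] Run(double[,,] P, double[,,] R, double gamma, double tolerance = 1e-6)
    {
        int states = P.GetLength(0), actions = P.GetLength(1);
        var v = new double[states];
        double delta;
        do
        {
            delta = 0.0;
            for (int s = 0; s < states; s++)
            {
                double best = double.NegativeInfinity;
                for (int a = 0; a < actions; a++)
                {
                    double q = 0.0;
                    for (int s2 = 0; s2 < states; s2++)
                        q += P[s, a, s2] * (R[s, a, s2] + gamma * v[s2]);
                    if (q > best) best = q;
                }
                delta = Math.Max(delta, Math.Abs(best - v[s]));
                v[s] = best;
            }
        } while (delta > tolerance);
        return v;
    }
}
```

The greedy policy with respect to the converged values is then the optimal policy for the modeled MDP.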
Replace temporary gradient calculation in MADDPGAgent.cs:304-325 with production-ready pattern following DoubleDQNAgent approach.

Changes:
- Use LossFunction.CalculateDerivative() for output gradients
- Call Backpropagate() to propagate error through critic network
- Apply gradient descent to all network parameters
- Update parameters via UpdateParameters() method

This resolves PR review comment about critic using squared error gradient instead of proper per-parameter gradient descent.

Co-Authored-By: Claude <[email protected]>
Replace incorrect gradient application logic with NotSupportedException following the established pattern from DoubleDQNAgent. The previous implementation assumed gradients.Length matched currentParams.Length, which would cause IndexOutOfRangeException since ComputeGradients returns output-space gradients (not parameter-space gradients). MADDPG manages gradient updates internally through backpropagation in the Train() method (lines 304-325), so external gradient application via ApplyGradients is not supported. The exception message clearly documents this and directs users to use the internal Train() loop instead.
Replace 'new Random()' with 'Random' (base class property) to preserve seeded/deterministic behavior from ReinforcementLearningAgentBase. Creating a new Random() instance breaks reproducibility when a seed is provided in the agent options. This ensures epsilon-greedy exploration uses the same seeded Random instance as the rest of the agent's stochastic operations, maintaining deterministic behavior for testing and reproducibility.
Replace new Random() with inherited Random instance to preserve seeded/deterministic behavior from ReinforcementLearningAgentBase options.
Replace new Random() with inherited Random instance to preserve seeded/deterministic behavior for reproducible testing.
Remove duplicate private Random field and use protected Random property from ReinforcementLearningAgentBase<T>. This ensures the agent uses the seeded RNG for reproducible testing. Changes: - Remove private Random _random field declaration - Remove _random initialization in constructor - Replace _random.NextDouble() with Random.NextDouble() - Replace _random.Next() with Random.Next() Fixes PR #481 review comment
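A tiny standalone illustration (plain C#, not code from the PR) of why the commits above keep reusing one seeded Random instead of constructing new instances:

```csharp
using System;

static class SeededRandomDemo
{
    static void Main()
    {
        var seeded = new Random(42);
        Console.WriteLine(seeded.NextDouble());       // same value every run on a given runtime: reproducible
        Console.WriteLine(new Random().NextDouble()); // freshly seeded: changes run to run,
                                                      // so a "seeded" agent stops being deterministic
    }
}
```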
Add _options.Validate() call immediately after options assignment in IQLAgent constructor. This ensures invalid IQLOptions settings fail fast and deterministically before any networks or buffers are created. Fixes PR #481 review comment at src/ReinforcementLearning/Agents/IQL/IQLAgent.cs:70
Call _options.Validate() immediately after assignment and null check in MADDPGAgent constructor (line 60) before any use of options properties (NumAgents, ReplayBufferSize, etc.) to ensure invalid configurations fail fast and deterministically before resource creation. Addresses PR #481 review comment.
Pass the seeded _random instance to MathHelper.GetNormalRandom() at line 210 to ensure deterministic Gaussian noise sampling for reproducible training results. Without this change, the agent would create new unseeded Random instances internally, making it impossible to achieve reproducible results even with a seed. Addresses PR #481 review comment.
Add optional Random parameter to MathHelper.GetNormalRandom() method to allow callers to pass in a seeded Random instance, preventing the creation of new unseeded Random instances on each call which breaks reproducibility. Update IQLAgent.cs to pass in its seeded _random instance when calling GetNormalRandom() at line 210, ensuring deterministic behavior for reproducible testing. Without this fix, each call to GetNormalRandom created a new Random() instance, making it impossible to achieve reproducible results even when agents were initialized with a seed. Addresses PR #481 review comment.
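A hedged sketch of what such an overload could look like, using the Box-Muller transform (the same math as the SampleStandardNormal helper suggested later in this review). The real MathHelper.GetNormalRandom<T> works over the generic numeric type T, so this plain-double version is illustrative only:

```csharp
using System;

static class NormalSamplerSketch
{
    // Gaussian sample via Box-Muller; pass a seeded Random for reproducible draws.
    public static double GetNormalRandom(double mean, double stdDev, Random? random = null)
    {
        var rng = random ?? new Random();   // falls back to an unseeded, non-deterministic instance
        double u1 = 1.0 - rng.NextDouble(); // avoid log(0)
        double u2 = 1.0 - rng.NextDouble();
        double standardNormal = Math.Sqrt(-2.0 * Math.Log(u1)) * Math.Sin(2.0 * Math.PI * u2);
        return mean + stdDev * standardNormal;
    }
}

// Usage: var noise = NormalSamplerSketch.GetNormalRandom(0.0, 1.0, seededRandom);
```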
Actionable comments posted: 8
♻️ Duplicate comments (17)
src/PredictionModelBuilder.cs (2)
641-642: Remove the pragma and fix the async method signature. The `#pragma warning disable CS1998` directive still suppresses the compiler warning about this async method lacking await operators, which was flagged in previous reviews. The method runs synchronously and blocks the calling thread during training. As previously recommended: either remove `async` and return `Task.FromResult(result)`, or wrap the training loop in `Task.Run` and await it.
697-700: Fix the loss filter to preserve negative losses. The condition `if (numOps.ToDouble(loss) > 0)` incorrectly excludes valid negative losses from tracking. As flagged in previous reviews, RL agents like DDPG produce negative actor losses, and these are valid training outputs. Change line 697 to exclude only the zero sentinel:

```diff
-    if (numOps.ToDouble(loss) > 0)
+    if (!numOps.Equals(loss, numOps.Zero))
```

src/ReinforcementLearning/Agents/ExpectedSARSA/ExpectedSARSAAgent.cs (1)
268-286: Critical: SetParameters clears Q-table before capturing state keys. Lines 271-273 clear `_qTable` then attempt to read its keys, resulting in an empty `stateKeys` list. The reconstruction loop (lines 276-285) never executes, making this method completely non-functional. This critical bug was flagged in previous reviews but the fix was not correctly applied. Capture state keys before clearing:

```diff
 public override void SetParameters(Vector<T> parameters)
 {
+    // Capture existing state keys before clearing
+    var stateKeys = _qTable.Keys.ToList();
+
     // Reconstruct Q-table from vector
     _qTable.Clear();
-    var stateKeys = _qTable.Keys.ToList();
     int maxStates = parameters.Length / _options.ActionSize;
```

Note: `GetParameters()` loses state key ordering, so this method can only restore Q-values if the original state keys are preserved. Consider returning state keys alongside parameters or documenting this limitation.

src/ReinforcementLearning/Agents/DuelingDQN/DuelingDQNAgent.cs (3)
58-78: Null check occurs after `options` is dereferenced in `base()` call. The constructor dereferences `options` fields (lines 61-71) in the `base()` argument before checking for null (line 74). If a caller passes `null`, a `NullReferenceException` will be thrown from the base constructor invocation instead of the intended `ArgumentNullException`. Based on past review comments, this issue was previously identified but remains unresolved. Apply the recommended fix by introducing a static helper:

```diff
+    private static ReinforcementLearningOptions<T> CreateBaseOptions(DuelingDQNOptions<T> options)
+    {
+        if (options == null)
+            throw new ArgumentNullException(nameof(options));
+
+        return new ReinforcementLearningOptions<T>
+        {
+            LearningRate = options.LearningRate,
+            DiscountFactor = options.DiscountFactor,
+            LossFunction = options.LossFunction,
+            Seed = options.Seed,
+            BatchSize = options.BatchSize,
+            ReplayBufferSize = options.ReplayBufferSize,
+            TargetUpdateFrequency = options.TargetUpdateFrequency,
+            WarmupSteps = options.WarmupSteps,
+            EpsilonStart = options.EpsilonStart,
+            EpsilonEnd = options.EpsilonEnd,
+            EpsilonDecay = options.EpsilonDecay
+        };
+    }
+
     public DuelingDQNAgent(DuelingDQNOptions<T> options)
-        : base(new ReinforcementLearningOptions<T>
-        {
-            LearningRate = options.LearningRate,
-            DiscountFactor = options.DiscountFactor,
-            LossFunction = options.LossFunction,
-            Seed = options.Seed,
-            BatchSize = options.BatchSize,
-            ReplayBufferSize = options.ReplayBufferSize,
-            TargetUpdateFrequency = options.TargetUpdateFrequency,
-            WarmupSteps = options.WarmupSteps,
-            EpsilonStart = options.EpsilonStart,
-            EpsilonEnd = options.EpsilonEnd,
-            EpsilonDecay = options.EpsilonDecay
-        })
+        : base(CreateBaseOptions(options))
     {
-        _options = options ?? throw new ArgumentNullException(nameof(options));
+        _options = options;
         _replayBuffer = new UniformReplayBuffer<T>(options.ReplayBufferSize, options.Seed);
```
236-255: `Deserialize` discards serialized hyperparameters, leaving `_options` stale. Lines 241-244 read `StateSize`, `ActionSize`, `LearningRate`, and `DiscountFactor` from the serialized data but immediately discard them. The `_options` field is never updated, so a deserialized agent retains its original constructor configuration rather than the saved state. Since training logic relies on `_options` fields like `BatchSize` (line 131, 178), `TargetUpdateFrequency` (line 181), and `EpsilonEnd` (line 186), a loaded agent may behave inconsistently with its trained state. Based on past review comments, this issue was previously identified but remains unresolved. Consider one of these approaches:

- Reconstruct `_options` from deserialized values (if you want full state restoration)
- Validate that deserialized hyperparameters match the current `_options` and throw an exception on mismatch
- Remove hyperparameter serialization if they're not intended to be restored

Recommended approach (validation):

```diff
 public override void Deserialize(byte[] data)
 {
     using var ms = new MemoryStream(data);
     using var reader = new BinaryReader(ms);
-    reader.ReadInt32(); // stateSize
-    reader.ReadInt32(); // actionSize
-    reader.ReadDouble(); // learningRate
-    reader.ReadDouble(); // discountFactor
+    var stateSize = reader.ReadInt32();
+    var actionSize = reader.ReadInt32();
+    var learningRate = reader.ReadDouble();
+    var discountFactor = reader.ReadDouble();
+
+    // Validate architecture matches
+    if (stateSize != _options.StateSize || actionSize != _options.ActionSize)
+        throw new InvalidOperationException($"Architecture mismatch: expected StateSize={_options.StateSize}, ActionSize={_options.ActionSize}, but got StateSize={stateSize}, ActionSize={actionSize}");
+
+    // Optionally validate hyperparameters
+    if (Math.Abs(learningRate - NumOps.ToDouble(LearningRate)) > 1e-6 ||
+        Math.Abs(discountFactor - NumOps.ToDouble(DiscountFactor)) > 1e-6)
+    {
+        // Log warning or throw exception
+    }
+
     _epsilon = reader.ReadDouble();
```
268-274: `SetParameters` leaves the target network stale after external parameter injection. `SetParameters` updates only `_qNetwork` (line 273) and never copies those parameters into `_targetNetwork`. Until the next scheduled sync in `Train()` (line 181-183), the target network will have inconsistent weights, which is surprising behavior if callers expect `SetParameters` to fully define the agent's state. Based on past review comments, this issue was previously identified but remains unresolved. Mirror the `DoubleDQNAgent` pattern (see relevant code snippets) and sync the target network:

```diff
 public override void SetParameters(Vector<T> parameters)
 {
     var matrix = new Matrix<T>(parameters.Length, 1);
     for (int i = 0; i < parameters.Length; i++)
         matrix[i, 0] = parameters[i];
     _qNetwork.SetFlattenedParameters(matrix);
+    CopyNetworkWeights(_qNetwork, _targetNetwork);
 }
```

src/ReinforcementLearning/Agents/DoubleDQN/DoubleDQNAgent.cs (2)
334-350: ComputeGradients returns output-space derivatives instead of parameter gradients. Line 349 returns `gradient` (∂Loss/∂output, size ≈ ActionSize) rather than the accumulated parameter gradients (∂Loss/∂θ) after backpropagation. External callers expecting parameter gradients will receive the wrong tensor shape and semantics.

278-281: Deserialize discards serialized hyperparameters, breaking model persistence. Lines 278-281 read `stateSize`, `actionSize`, `learningRate`, and `discountFactor` but immediately discard them. A deserialized agent retains constructor defaults instead of the saved configuration, so a checkpoint saved with LR=0.001 loads with whatever `_options.LearningRate` was at construction. Apply this diff to validate structure and restore scalars:

```diff
-    reader.ReadInt32(); // stateSize
-    reader.ReadInt32(); // actionSize
-    reader.ReadDouble(); // learningRate
-    reader.ReadDouble(); // discountFactor
+    var stateSize = reader.ReadInt32();
+    if (stateSize != _options.StateSize)
+    {
+        throw new InvalidOperationException(
+            $"Serialized state size {stateSize} does not match configured {_options.StateSize}.");
+    }
+
+    var actionSize = reader.ReadInt32();
+    if (actionSize != _options.ActionSize)
+    {
+        throw new InvalidOperationException(
+            $"Serialized action size {actionSize} does not match configured {_options.ActionSize}.");
+    }
+
+    var learningRate = reader.ReadDouble();
+    LearningRate = NumOps.FromDouble(learningRate);
+
+    var discountFactor = reader.ReadDouble();
+    DiscountFactor = NumOps.FromDouble(discountFactor);
     _epsilon = reader.ReadDouble();
     _steps = reader.ReadInt32();
```
568-574: Critical: ComputeGradients returns parameter values instead of gradients. The method returns `GetParameters()`, which are the model weights, not the gradients (derivatives) that the method contract requires. This is semantically incorrect. For an offline RL agent where supervised gradient computation doesn't naturally apply, either throw `NotSupportedException` or return zero gradients:

Option 1: Throw NotSupportedException (recommended)

```diff
 public override Vector<T> ComputeGradients(
     Vector<T> input,
     Vector<T> target,
     ILossFunction<T>? lossFunction = null)
 {
-    return GetParameters();
+    throw new NotSupportedException(
+        "IQL uses specialized offline RL gradients computed in Train(). " +
+        "Use Train() for gradient-based updates.");
 }
```

Option 2: Return zero gradients

```diff
 public override Vector<T> ComputeGradients(
     Vector<T> input,
     Vector<T> target,
     ILossFunction<T>? lossFunction = null)
 {
-    return GetParameters();
+    // IQL uses internal gradient computation via Train()
+    return new Vector<T>(GetParameters().Length);
 }
```
378-430: Critical: Policy gradient still missing tanh chain rule and factor of 2. Despite past review comments marking this as addressed, the policy gradient at lines 414-420 remains incorrect. The gradient `weight * (predictedAction - action)` omits both:

- The factor of 2 from the MSE derivative: d/dx(error²) = 2 * error
- The tanh chain rule: d/dmean tanh(mean) = 1 - tanh²(mean)

Additionally, calling `SelectAction` at line 403 for every experience performs redundant forward passes. For loss L = weight * (tanh(mean) - action)², the correct gradient w.r.t. the pre-activation mean is:

dL/dmean = 2 * weight * (tanh(mean) - action) * (1 - tanh(mean)²)

Apply this diff to fix the gradient computation and eliminate redundant forward passes (a small numeric check of this derivative follows the diff):
private T UpdatePolicy(List<ReplayBuffers.Experience<T>> batch) { T totalLoss = _numOps.Zero; foreach (var experience in batch) { // Compute advantage: A(s,a) = Q(s,a) - V(s) var stateAction = ConcatenateStateAction(experience.State, experience.Action); var stateActionTensor = Tensor<T>.FromVector(stateAction); var q1OutputTensor = _q1Network.Predict(stateActionTensor); var q1Value = q1OutputTensor.ToVector()[0]; var q2OutputTensor = _q2Network.Predict(stateActionTensor); var q2Value = q2OutputTensor.ToVector()[0]; var qValue = MathHelper.Min<T>(q1Value, q2Value); var stateTensor = Tensor<T>.FromVector(experience.State); var vOutputTensor = _valueNetwork.Predict(stateTensor); var vValue = vOutputTensor.ToVector()[0]; var advantage = _numOps.Subtract(qValue, vValue); // Advantage-weighted regression: exp(advantage / temperature) * log_prob(a|s) var weight = NumOps.Exp(_numOps.Divide(advantage, _options.Temperature)); weight = MathHelper.Clamp<T>(weight, _numOps.FromDouble(0.0), _numOps.FromDouble(100.0)); - // Simplified policy loss (weighted MSE to match action) - var predictedAction = SelectAction(experience.State, training: false); + // Policy loss (weighted MSE to match action) + var policyOutputTensor = _policyNetwork.Predict(stateTensor); + var policyOutput = policyOutputTensor.ToVector(); + + var mean = new Vector<T>(_options.ActionSize); + for (int i = 0; i < _options.ActionSize; i++) + { + mean[i] = policyOutput[i]; + } + T actionDiff = _numOps.Zero; + var gradientVec = new Vector<T>(_options.ActionSize * 2); + for (int i = 0; i < _options.ActionSize; i++) { - var diff = _numOps.Subtract(experience.Action[i], predictedAction[i]); + var tanhMean = MathHelper.Tanh<T>(mean[i]); + var diff = _numOps.Subtract(tanhMean, experience.Action[i]); actionDiff = _numOps.Add(actionDiff, _numOps.Multiply(diff, diff)); + + // MSE gradient: 2 * weight * (tanh(mean) - action) + var mseGrad = _numOps.Multiply(_numOps.FromDouble(2.0), _numOps.Multiply(weight, diff)); + // Tanh derivative: (1 - tanh(mean)^2) + var tanhDeriv = _numOps.Subtract(_numOps.One, _numOps.Multiply(tanhMean, tanhMean)); + // Chain rule: dL/dmean = mseGrad * tanhDeriv + gradientVec[i] = _numOps.Multiply(mseGrad, tanhDeriv); } var policyLoss = _numOps.Multiply(weight, actionDiff); totalLoss = _numOps.Add(totalLoss, policyLoss); - // Backpropagate - var gradientVec = new Vector<T>(_options.ActionSize * 2); - for (int i = 0; i < _options.ActionSize; i++) - { - var diff = _numOps.Subtract(predictedAction[i], experience.Action[i]); - gradientVec[i] = _numOps.Multiply(weight, diff); - } - var gradientTensor = Tensor<T>.FromVector(gradientVec); _policyNetwork.Backpropagate(gradientTensor); var gradients = _policyNetwork.GetParameterGradients(); _policyNetwork.ApplyGradients(gradients, _options.PolicyLearningRate); } return _numOps.Divide(totalLoss, _numOps.FromDouble(batch.Count)); }
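A quick numeric sanity check (plain C#, illustrative, not part of the PR) that the stated gradient matches a finite-difference estimate of L = weight * (tanh(mean) - action)²:

```csharp
using System;

static class GradientCheck
{
    static double Loss(double mean, double action, double weight)
        => weight * Math.Pow(Math.Tanh(mean) - action, 2);

    static void Main()
    {
        double mean = 0.3, action = -0.2, weight = 1.7, h = 1e-6;
        // Analytic gradient from the review: 2 * weight * (tanh(mean) - action) * (1 - tanh(mean)^2)
        double analytic = 2 * weight * (Math.Tanh(mean) - action) * (1 - Math.Pow(Math.Tanh(mean), 2));
        // Central finite difference for comparison
        double numeric = (Loss(mean + h, action, weight) - Loss(mean - h, action, weight)) / (2 * h);
        Console.WriteLine($"analytic={analytic:F6}, finite-difference={numeric:F6}");
    }
}
```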
205-213: Critical: GetNormalRandom does not accept 3 arguments. Line 210 calls `MathHelper.GetNormalRandom<T>(_numOps.Zero, _numOps.One, _random)`, but the method only accepts 2 parameters (mean and stdDev), causing a compilation failure. Additionally, `GetNormalRandom` creates a new `Random` instance internally, breaking reproducibility and ignoring the agent's seeded RNG. Add a helper method that uses the agent's `_random` instance:

```diff
+    private T SampleStandardNormal()
+    {
+        var u1 = 1.0 - _random.NextDouble();
+        var u2 = 1.0 - _random.NextDouble();
+        var randStdNormal = Math.Sqrt(-2.0 * Math.Log(u1)) * Math.Sin(2.0 * Math.PI * u2);
+        return _numOps.FromDouble(randStdNormal);
+    }
+
     // Sample from Gaussian policy
     var action = new Vector<T>(_options.ActionSize);
     for (int i = 0; i < _options.ActionSize; i++)
     {
         var std = NumOps.Exp(logStd[i]);
-        var noise = MathHelper.GetNormalRandom<T>(_numOps.Zero, _numOps.One, _random);
+        var noise = SampleStandardNormal();
         var rawAction = _numOps.Add(mean[i], _numOps.Multiply(std, noise));
         action[i] = MathHelper.Tanh<T>(rawAction);
     }
```

Don't forget to add `using System;` at the top if not already present.
637-648: `ComputeGradients` uses supervised learning pattern without documentation. This method compares `Predict(input)` against a `target` using a loss function, which is the supervised learning paradigm. In MADDPG, gradients come from the deterministic policy gradient (backpropagating Q-values), not input-target comparison. The method exists for `IFullModel` interface compliance but will produce incorrect gradients if invoked during RL training. Previous review suggested either throwing `NotSupportedException` or documenting that it's for supervised fine-tuning only, but the current implementation does neither. Apply one of these fixes:

Option 1 (recommended): Throw `NotSupportedException`

```diff
 public override Vector<T> ComputeGradients(
     Vector<T> input,
     Vector<T> target,
     ILossFunction<T>? lossFunction = null)
 {
-    var prediction = Predict(input);
-    var usedLossFunction = lossFunction ?? LossFunction;
-    var loss = usedLossFunction.CalculateLoss(prediction, target);
-
-    var gradient = usedLossFunction.CalculateDerivative(prediction, target);
-    return gradient;
+    throw new NotSupportedException(
+        "ComputeGradients is not applicable to MADDPG. Use Train() for RL updates.");
 }
```

Option 2: Add clear documentation

```diff
+    /// <summary>
+    /// Computes supervised learning gradients. NOT used for MADDPG RL training.
+    /// </summary>
+    /// <remarks>
+    /// This method is for supervised fine-tuning only. MADDPG RL training uses Train().
+    /// </remarks>
     public override Vector<T> ComputeGradients(
```
296-310: SetParameters only updates existing entries — throw NotSupportedException.
`SetParameters` iterates over existing `_qTable` entries. If the Q-table is empty (e.g., right after construction), this method is effectively a no-op. For tabular agents, reconstructing state keys from a flat parameter vector is not feasible. Replace with an explicit exception:
public override void SetParameters(Vector<T> parameters) { - int index = 0; - foreach (var stateEntry in _qTable.ToList()) - { - for (int a = 0; a < _options.ActionSize; a++) - { - if (index < parameters.Length) - { - _qTable[stateEntry.Key][a] = parameters[index]; - index++; - } - } - } + throw new NotSupportedException( + "SetParameters is not supported for tabular Monte Carlo agents. " + + "State keys cannot be recovered from parameter vectors. " + + "Use Clone() or Serialize/Deserialize for state preservation."); }
196-200: Predict mutates training state — introduce separate read-only prediction path. Line 198 sets `_isFirstAction = false`, mutating training state in what should be a read-only prediction method. If `Predict` is called during an episode, it will disable the exploring start for that episode, corrupting the Monte Carlo Exploring Starts guarantee. Consider extracting greedy action selection into a separate method:
public override Vector<T> Predict(Vector<T> input) { - _isFirstAction = false; - return SelectAction(input, training: false); + return SelectGreedyAction(input); } + +private Vector<T> SelectGreedyAction(Vector<T> state) +{ + EnsureStateExists(state); + string stateKey = GetStateKey(state); + + int bestAction = 0; + T bestValue = _qTable[stateKey][0]; + + for (int a = 1; a < _options.ActionSize; a++) + { + if (NumOps.GreaterThan(_qTable[stateKey][a], bestValue)) + { + bestValue = _qTable[stateKey][a]; + bestAction = a; + } + } + + var result = new Vector<T>(_options.ActionSize); + result[bestAction] = NumOps.One; + return result; +}src/ReinforcementLearning/Agents/EligibilityTraces/SARSALambdaAgent.cs (3)
53-92: Clarify parameter semantics: unused `nextState` parameter causes confusion. The `nextState` parameter is never used in this method. The SARSA(λ) implementation uses a shifted calling convention where `state` represents the next state from the previous transition, and `action` represents the next action. This differs from standard RL signatures like `StoreExperience(s, a, r, s', done)` and makes the code confusing for callers. Consider either:

- Removing the unused `nextState` parameter and documenting the shifted semantics in XML comments, or
- Refactoring to match standard parameter semantics without relying on `_lastState`/`_lastAction` state tracking.

If keeping the current approach, add XML documentation to clarify:
+ /// <summary> + /// Stores experience for SARSA(λ) learning. + /// Note: Due to SARSA's on-policy nature, this method uses shifted semantics: + /// - 'state' is the state reached after taking the previous action + /// - 'action' is the action taken at 'state' + /// - 'reward' is the reward received from the previous transition + /// - 'nextState' is unused (kept for interface compatibility) + /// </summary> public override void StoreExperience(Vector<T> state, Vector<T> action, T reward, Vector<T> nextState, bool done)
18-25: Ineffective null check: use a boolean flag to track first experience. Line 24 initializes `_lastState` to a new `Vector<T>`, not `null`. This means the null check at line 55 will always evaluate to true, even on the first call to `StoreExperience` in an episode. This causes incorrect TD updates for the first transition. Apply this diff to use a boolean flag:
private Vector<T> _lastState; private int _lastAction; + private bool _isFirstExperience; public SARSALambdaAgent(SARSALambdaOptions<T> options) : base(options) { _options = options ?? throw new ArgumentNullException(nameof(options)); _qTable = new Dictionary<string, Dictionary<int, T>>(); _eligibilityTraces = new Dictionary<string, Dictionary<int, T>>(); _epsilon = options.EpsilonStart; - _lastState = new Vector<T>(options.StateSize); + _isFirstExperience = true; }Then update
StoreExperienceat line 55:- if (_lastState != null) + if (!_isFirstExperience) { // ... existing logic ... } _lastState = state; _lastAction = ArgMax(action); + _isFirstExperience = false; if (done) { ResetEpisode(); _epsilon = Math.Max(_options.EpsilonEnd, _epsilon * _options.EpsilonDecay); + _isFirstExperience = true; }Also update
ResetEpisodeat line 129:- public override void ResetEpisode() { _lastState = new Vector<T>(_options.StateSize); foreach (var s in _eligibilityTraces.Keys.ToList()) { for (int a = 0; a < _options.ActionSize; a++) _eligibilityTraces[s][a] = NumOps.Zero; } } + public override void ResetEpisode() { _isFirstExperience = true; foreach (var s in _eligibilityTraces.Keys.ToList()) { for (int a = 0; a < _options.ActionSize; a++) _eligibilityTraces[s][a] = NumOps.Zero; } }
182-182: Add validation to SetParameters to prevent silent failures. The method lacks validation and could fail silently or produce incorrect behavior if the `parameters` vector has incompatible dimensions. A past review comment flagged this issue but it remains unaddressed. Apply this diff to add validation:
- public override void SetParameters(Vector<T> parameters) { int idx = 0; foreach (var s in _qTable.ToList()) for (int a = 0; a < _options.ActionSize; a++) if (idx < parameters.Length) _qTable[s.Key][a] = parameters[idx++]; } + public override void SetParameters(Vector<T> parameters) + { + if (parameters == null) + throw new ArgumentNullException(nameof(parameters)); + + int expectedCount = _qTable.Count * _options.ActionSize; + if (parameters.Length != expectedCount) + throw new ArgumentException($"Expected parameter vector of length {expectedCount}, got {parameters.Length}", nameof(parameters)); + + int idx = 0; + foreach (var s in _qTable.ToList()) + for (int a = 0; a < _options.ActionSize; a++) + _qTable[s.Key][a] = parameters[idx++]; + }
🧹 Nitpick comments (5)
src/PredictionModelBuilder.cs (1)
777-788: Consider extracting ComputeAverage to a shared utility. This method duplicates `ComputeAverage` from `ReinforcementLearningAgentBase.cs` (lines 328-339). Since both classes need generic numeric averaging, consider extracting this to `MathHelper` as a static extension method to eliminate duplication. Example extraction to MathHelper:

```csharp
// In MathHelper.cs
public static T ComputeAverage<T>(this INumericOperations<T> numOps, IEnumerable<T> values)
{
    var list = values.ToList();
    if (list.Count == 0) return numOps.Zero;

    T sum = numOps.Zero;
    foreach (var value in list)
    {
        sum = numOps.Add(sum, value);
    }
    return numOps.Divide(sum, numOps.FromDouble(list.Count));
}
```

Then use as:

```csharp
var avgReward = numOps.ComputeAverage(recentRewards);
```

src/ReinforcementLearning/Agents/DoubleDQN/DoubleDQNAgent.cs (1)
360-365: LGTM: ApplyGradients now prevents misuse via NotSupportedException. Throwing prevents the `IndexOutOfRangeException` from dimension mismatch. However, note that `ComputeGradients` (lines 334-350) still returns output-space gradients, creating an inconsistent API: one method returns a result while its counterpart refuses to accept it. Consider also throwing in `ComputeGradients` if external gradient APIs are unsupported, or fix both to handle parameter gradients consistently.

src/ReinforcementLearning/Agents/IQL/IQLAgent.cs (2)
162-165: Remove unused InitializeBuffer method. This method is never called and the buffer is already initialized in the constructor at line 85.

```diff
-    private void InitializeBuffer()
-    {
-        _offlineBuffer = new UniformReplayBuffer<T>(_options.BufferSize);
-    }
```
577-581: Consider throwing NotSupportedException for consistency. If `ComputeGradients` throws `NotSupportedException`, this method should also throw for consistency, as both are part of the supervised learning interface that doesn't apply to offline RL.

```diff
 public override void ApplyGradients(Vector<T> gradients, T learningRate)
 {
-    // IQL uses offline training with separate network updates
-    // Gradient application is handled by individual network updates
+    throw new NotSupportedException(
+        "IQL uses specialized offline RL gradient updates. " +
+        "Use Train() to update the agent.");
 }
```

src/ReinforcementLearning/Agents/EligibilityTraces/SARSALambdaAgent.cs (1)
204-205: Document that gradient methods are not used for SARSA(λ) learning. A previous review noted that `ComputeGradients` and `ApplyGradients` are misleading for Q-learning agents because SARSA(λ) uses temporal-difference (TD) updates, not gradient descent. While the comment was marked as addressed, the methods remain present (likely for base class interface compliance). Consider adding XML documentation to clarify:

```diff
+    /// <summary>
+    /// Computes gradients for interface compliance. Note: SARSA(λ) uses TD updates, not gradient descent.
+    /// This method is not used for actual learning.
+    /// </summary>
     public override Vector<T> ComputeGradients(Vector<T> input, Vector<T> target, ILossFunction<T>? lossFunction = null)
     {
         var pred = Predict(input);
         var lf = lossFunction ?? LossFunction;
         var loss = lf.CalculateLoss(pred, target);
         var grad = lf.CalculateDerivative(pred, target);
         return grad;
     }
+
+    /// <summary>
+    /// No-op for SARSA(λ). Learning occurs via TD updates in StoreExperience.
+    /// </summary>
     public override void ApplyGradients(Vector<T> gradients, T learningRate) { }
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (11)
- `src/Models/Options/ExpectedSARSAOptions.cs` (1 hunks)
- `src/PredictionModelBuilder.cs` (3 hunks)
- `src/ReinforcementLearning/Agents/DoubleDQN/DoubleDQNAgent.cs` (1 hunks)
- `src/ReinforcementLearning/Agents/DuelingDQN/DuelingDQNAgent.cs` (1 hunks)
- `src/ReinforcementLearning/Agents/EligibilityTraces/SARSALambdaAgent.cs` (1 hunks)
- `src/ReinforcementLearning/Agents/ExpectedSARSA/ExpectedSARSAAgent.cs` (1 hunks)
- `src/ReinforcementLearning/Agents/IQL/IQLAgent.cs` (1 hunks)
- `src/ReinforcementLearning/Agents/MADDPG/MADDPGAgent.cs` (1 hunks)
- `src/ReinforcementLearning/Agents/MonteCarlo/EveryVisitMonteCarloAgent.cs` (1 hunks)
- `src/ReinforcementLearning/Agents/MonteCarlo/FirstVisitMonteCarloAgent.cs` (1 hunks)
- `src/ReinforcementLearning/Agents/MonteCarlo/MonteCarloExploringStartsAgent.cs` (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- src/ReinforcementLearning/Agents/MonteCarlo/EveryVisitMonteCarloAgent.cs
🧰 Additional context used
🧬 Code graph analysis (10)
src/ReinforcementLearning/Agents/MADDPG/MADDPGAgent.cs (5)
src/ReinforcementLearning/Agents/DeepReinforcementLearningAgentBase.cs (2)
DeepReinforcementLearningAgentBase(33-105)DeepReinforcementLearningAgentBase(65-69)src/Models/Options/MADDPGOptions.cs (3)
MADDPGOptions(33-73)MADDPGOptions(53-59)Validate(64-72)src/NeuralNetworks/FeedForwardNeuralNetwork.cs (5)
Tensor(131-146)Tensor(166-178)Tensor(199-208)Train(258-292)UpdateParameters(226-239)src/NeuralNetworks/NeuralNetwork.cs (3)
Tensor(179-197)Train(228-262)UpdateParameters(141-154)src/Helpers/MathHelper.cs (1)
MathHelper(19-992)
src/ReinforcementLearning/Agents/MonteCarlo/FirstVisitMonteCarloAgent.cs (6)
src/ReinforcementLearning/Agents/MonteCarlo/MonteCarloExploringStartsAgent.cs (15)
T(88-92)T(148-162)Dictionary(181-188)Vector(42-73)Vector(196-200)Vector(271-294)Vector(337-348)StoreExperience(75-86)UpdateFromEpisode(94-127)EnsureStateExists(129-141)ResetEpisode(190-194)Serialize(225-236)Deserialize(238-269)SetParameters(296-310)IFullModel(312-335)src/ReinforcementLearning/Agents/MonteCarlo/OnPolicyMonteCarloAgent.cs (15)
T(90-94)T(150-164)Dictionary(183-191)Vector(42-73)Vector(198-201)Vector(258-281)Vector(323-334)StoreExperience(75-88)UpdateFromEpisode(96-129)EnsureStateExists(131-143)ResetEpisode(193-196)Serialize(226-237)Deserialize(239-256)SetParameters(283-297)IFullModel(299-321)src/ReinforcementLearning/Agents/MonteCarlo/OffPolicyMonteCarloAgent.cs (14)
T(85-89)Dictionary(190-197)Vector(40-71)Vector(204-207)Vector(262-285)Vector(321-332)StoreExperience(73-83)UpdateFromEpisode(91-129)EnsureStateExists(151-166)ResetEpisode(199-202)Serialize(232-242)Deserialize(244-260)SetParameters(287-301)IFullModel(303-319)src/ReinforcementLearning/Agents/ReinforcementLearningAgentBase.cs (3)
ReinforcementLearningAgentBase(34-415)ReinforcementLearningAgentBase(90-111)Train(170-175)src/Models/Options/MonteCarloOptions.cs (1)
MonteCarloOptions(9-13)src/ReinforcementLearning/Agents/MonteCarlo/EveryVisitMonteCarloAgent.cs (3)
VectorToStateKey(106-114)UpdateFromEpisode(75-97)EnsureStateExists(141-151)
src/ReinforcementLearning/Agents/ExpectedSARSA/ExpectedSARSAAgent.cs (2)
src/ReinforcementLearning/Agents/ReinforcementLearningAgentBase.cs (3)
ReinforcementLearningAgentBase(34-415)ReinforcementLearningAgentBase(90-111)Train(170-175)src/Models/Options/ExpectedSARSAOptions.cs (2)
ExpectedSARSAOptions(9-87)ExpectedSARSAOptions(22-42)
src/PredictionModelBuilder.cs (3)
src/ReinforcementLearning/Agents/ReinforcementLearningAgentBase.cs (6)
T(137-137)T(329-340)ResetEpisode(142-145)StoreExperience(131-131)Train(170-175)IEnumerable(234-237)src/Helpers/MathHelper.cs (2)
MathHelper(19-992)INumericOperations(36-66)src/ReinforcementLearning/Interfaces/IRLAgent.cs (2)
ResetEpisode(76-76)StoreExperience(55-55)
src/ReinforcementLearning/Agents/DoubleDQN/DoubleDQNAgent.cs (3)
src/ReinforcementLearning/Agents/DQN/DQNAgent.cs (17)
T(165-247)T(420-433)NeuralNetwork(103-131)CopyNetworkWeights(397-401)Vector(134-155)Vector(322-325)Vector(362-378)ArgMax(403-418)StoreExperience(158-162)Dictionary(250-257)Serialize(270-294)Deserialize(297-319)SetParameters(328-333)IFullModel(336-359)ApplyGradients(381-393)SaveModel(435-439)LoadModel(442-446)src/Models/Options/DoubleDQNOptions.cs (2)
DoubleDQNOptions(9-32)DoubleDQNOptions(26-31)src/NeuralNetworks/NeuralNetwork.cs (3)
Tensor(179-197)Train(228-262)UpdateParameters(141-154)
src/ReinforcementLearning/Agents/EligibilityTraces/SARSALambdaAgent.cs (3)
src/ReinforcementLearning/Agents/EligibilityTraces/WatkinsQLambdaAgent.cs (18)
T(100-100)Dictionary(101-101)Vector(24-32)Vector(103-103)Vector(140-154)Vector(175-175)EnsureStateExists(81-94)GetStateKey(96-96)ArgMax(98-98)StoreExperience(34-79)ResetEpisode(102-102)Serialize(109-120)Deserialize(122-139)SetParameters(155-155)IFullModel(156-174)ApplyGradients(176-176)SaveModel(177-177)LoadModel(178-178)src/ReinforcementLearning/Agents/ReinforcementLearningAgentBase.cs (3)
ReinforcementLearningAgentBase(34-415)ReinforcementLearningAgentBase(90-111)Train(170-175)src/Models/Options/SARSALambdaOptions.cs (1)
SARSALambdaOptions(5-10)
src/ReinforcementLearning/Agents/MonteCarlo/MonteCarloExploringStartsAgent.cs (5)
src/ReinforcementLearning/Agents/MonteCarlo/FirstVisitMonteCarloAgent.cs (15)
T(129-132)Vector(58-75)Vector(234-250)Vector(308-311)EnsureStateExists(156-166)StoreExperience(77-90)UpdateFromEpisode(92-127)ResetEpisode(185-189)Serialize(202-213)Deserialize(215-232)SetParameters(252-282)IFullModel(284-306)ApplyGradients(313-313)SaveModel(315-324)LoadModel(326-340)src/ReinforcementLearning/Agents/MonteCarlo/OnPolicyMonteCarloAgent.cs (17)
T(90-94)T(150-164)Vector(42-73)Vector(198-201)Vector(258-281)Vector(323-334)EnsureStateExists(131-143)StoreExperience(75-88)UpdateFromEpisode(96-129)ResetEpisode(193-196)Serialize(226-237)Deserialize(239-256)SetParameters(283-297)IFullModel(299-321)ApplyGradients(336-339)SaveModel(341-346)LoadModel(348-353)src/ReinforcementLearning/Agents/MonteCarlo/OffPolicyMonteCarloAgent.cs (16)
T(85-89)Vector(40-71)Vector(204-207)Vector(262-285)Vector(321-332)EnsureStateExists(151-166)StoreExperience(73-83)UpdateFromEpisode(91-129)ResetEpisode(199-202)Serialize(232-242)Deserialize(244-260)SetParameters(287-301)IFullModel(303-319)ApplyGradients(334-337)SaveModel(339-348)LoadModel(350-364)src/ReinforcementLearning/Agents/ReinforcementLearningAgentBase.cs (3)
ReinforcementLearningAgentBase(34-415)ReinforcementLearningAgentBase(90-111)Train(170-175)src/Models/Options/MonteCarloExploringStartsOptions.cs (1)
MonteCarloExploringStartsOptions(11-15)
src/Models/Options/ExpectedSARSAOptions.cs (1)
src/ReinforcementLearning/Agents/ExpectedSARSA/ExpectedSARSAAgent.cs (2)
T(112-141)T(143-146)
src/ReinforcementLearning/Agents/IQL/IQLAgent.cs (2)
src/Models/Options/IQLOptions.cs (2)
IQLOptions(31-64)IQLOptions(54-63)src/Helpers/MathHelper.cs (2)
INumericOperations(36-66)MathHelper(19-992)
src/ReinforcementLearning/Agents/DuelingDQN/DuelingDQNAgent.cs (3)
src/ReinforcementLearning/Agents/DoubleDQN/DoubleDQNAgent.cs (9)
T(151-226)ReinforcementLearningOptions(75-96)CopyNetworkWeights(368-371)Vector(124-142)Vector(295-298)Vector(334-350)ArgMax(373-388)SetParameters(301-305)ApplyGradients(360-365)src/Models/Options/DuelingDQNOptions.cs (2)
DuelingDQNOptions(9-34)DuelingDQNOptions(28-33)src/ReinforcementLearning/Agents/ReinforcementLearningAgentBase.cs (2)
ReinforcementLearningOptions(421-487)Train(170-175)
🪛 GitHub Actions: Build
src/ReinforcementLearning/Agents/IQL/IQLAgent.cs
[error] 70-70: IQLOptions does not contain a definition for 'Validate' and no accessible extension method 'Validate' accepting a first argument of type 'IQLOptions' could be found (are you missing a using directive or an assembly reference?)
🪛 GitHub Actions: Quality Gates (.NET)
src/ReinforcementLearning/Agents/IQL/IQLAgent.cs
[error] 70-70: CS1061: 'IQLOptions' does not contain a definition for 'Validate' and no accessible extension method 'Validate' accepting a first argument of type 'IQLOptions' could be found.
🪛 GitHub Check: Build All Frameworks
src/ReinforcementLearning/Agents/IQL/IQLAgent.cs
[failure] 210-210:
No overload for method 'GetNormalRandom' takes 3 arguments
[failure] 70-70:
'IQLOptions' does not contain a definition for 'Validate' and no accessible extension method 'Validate' accepting a first argument of type 'IQLOptions' could be found (are you missing a using directive or an assembly reference?)
[failure] 210-210:
No overload for method 'GetNormalRandom' takes 3 arguments
[failure] 70-70:
'IQLOptions' does not contain a definition for 'Validate' and no accessible extension method 'Validate' accepting a first argument of type 'IQLOptions' could be found (are you missing a using directive or an assembly reference?)
[failure] 210-210:
No overload for method 'GetNormalRandom' takes 3 arguments
[failure] 70-70:
'IQLOptions' does not contain a definition for 'Validate' and no accessible extension method 'Validate' accepting a first argument of type 'IQLOptions' could be found (are you missing a using directive or an assembly reference?)
[failure] 210-210:
No overload for method 'GetNormalRandom' takes 3 arguments
[failure] 70-70:
'IQLOptions' does not contain a definition for 'Validate' and no accessible extension method 'Validate' accepting a first argument of type 'IQLOptions' could be found (are you missing a using directive or an assembly reference?)
🪛 GitHub Check: Publish Size Analysis
src/ReinforcementLearning/Agents/IQL/IQLAgent.cs
[failure] 210-210:
No overload for method 'GetNormalRandom' takes 3 arguments
[failure] 70-70:
'IQLOptions' does not contain a definition for 'Validate' and no accessible extension method 'Validate' accepting a first argument of type 'IQLOptions' could be found (are you missing a using directive or an assembly reference?)
🔇 Additional comments (54)
src/PredictionModelBuilder.cs (3)
67-67: LGTM! The RL environment field is properly declared following the builder's existing patterns.

740-756: Good fix — OptimizationResult and deployment config now correctly implemented. The code now uses object initializer syntax for `OptimizationResult` (addressing the previous constructor parameter issue) and properly creates the `DeploymentConfiguration` from individual configs (addressing the missing deployment config issue).

1105-1130: LGTM! The `ConfigureEnvironment` method follows the builder pattern correctly and provides clear documentation with practical examples for RL training setup.

src/Models/Options/ExpectedSARSAOptions.cs (1)

22-86: Validation implementation looks solid. The constructor and init accessor validation work together to prevent invalid configurations. The constructor enforces validation for the standard construction path, while the init accessors provide a secondary defense for object initializer syntax. This addresses all previous concerns about uninitialized properties and missing validation.
src/ReinforcementLearning/Agents/ExpectedSARSA/ExpectedSARSAAgent.cs (7)
39-69: Defensive validation properly implemented. The constructor now validates both `StateSize` and `ActionSize` to prevent runtime failures when properties are left at their default zero value. This addresses the concerns from previous reviews about uninitialized properties bypassing init-accessor validation.

71-110: Action selection and experience storage correctly implement Expected SARSA. The epsilon-greedy policy and TD update using the expected value under the current policy are properly implemented. The epsilon decay at line 109 ensures the agent gradually shifts from exploration to exploitation.

208-216: ModelType assignment correctly fixed. Now properly uses the `ModelType.ReinforcementLearning` enum value instead of a string literal, addressing the type mismatch flagged in previous reviews.

221-249: Serialization implementation is functional. The JSON-based serialization with proper null checks provides a working persistence mechanism. The use of `dynamic` in deserialization (line 241) is pragmatic for this use case.

288-304: Clone correctly performs deep copy. The deep copy implementation (lines 295-300) properly creates independent Q-table instances for the clone by constructing new inner dictionaries. This fixes the shallow copy issue flagged in previous reviews.

306-309: ComputeGradients return type correctly fixed. Now returns `Vector<T>` (via `GetParameters()`) matching the method signature, addressing the type mismatch flagged in previous reviews.

313-338: Persistence methods properly implemented. Both `SaveModel` and `LoadModel` now have functional implementations with appropriate validation (null/whitespace checks, file existence), addressing the `NotImplementedException` issue from previous reviews.
305-314: LGTM!ComputeGradientsnow properly signals unsupported operation.The implementation correctly throws
NotSupportedExceptionwith a clear message, which is the right approach since theDuelingNetworkstores gradients internally without exposing them. This is much better than the previous stub that returned a dummy vector, and prevents silent failures in external optimizer code.
317-337: LGTM!ApplyGradientscorrectly implements gradient descent updates.The implementation properly uses
Vectorindexing (gradients[i], line 328), bounds-checks the gradient vector, and correctly applies the gradient descent rule:newParams[i] = currentParams[i] - learningRate * gradValue(lines 329-330). This addresses the previous review comments about incorrect Matrix indexing.
492-495: LGTM!Predictmethod addition resolves interface compatibility.The
Predictmethod properly wrapsForwardand is well-documented. This addresses previous review comments about missingPredictmethod calls throughout the agent (lines 111, 150, 154) that would have caused runtime exceptions.
497-553: LGTM!Backwardnow implements complete backpropagation through the dueling architecture.The implementation correctly computes gradients for the dueling architecture:
- Value gradient as the sum of all Q-gradients (lines 505-509)
- Centered advantage gradients accounting for mean subtraction (lines 512-517)
- Proper backpropagation through all three streams (advantage, value, shared)
- Correct gradient combination from both streams for shared layers (lines 541-545)
This fully addresses the previous critical review comment about incomplete backpropagation.
555-574: LGTM!UpdateWeightsnow properly updates all network parameters.The implementation correctly invokes
UpdateParameters(learningRate)on all layers across the three streams (shared, value, advantage). This fully addresses the previous critical review comment about stub implementation preventing learning.
576-668: LGTM!GetFlattenedParametersandSetFlattenedParametersare fully implemented.Both methods are now complete:
GetFlattenedParameterscorrectly collects parameters from all three streams (shared, value, advantage) in a consistent orderSetFlattenedParametersproperly distributes parameters with offset tracking and includes validation to ensure parameter count matches (lines 664-667)This fully addresses the previous critical review comments about stub implementations breaking serialization and cloning.
670-764: LGTM! Network serialization is fully implemented with proper validation.Both
SerializeandDeserializeare complete:
Serializepersists architecture metadata and all parametersDeserializeincludes architecture validation (lines 730-736) that throwsInvalidOperationExceptionon mismatch- Proper null checking and error handling
This fully addresses the previous critical review comments about stub implementations breaking model persistence.
src/ReinforcementLearning/Agents/MonteCarlo/FirstVisitMonteCarloAgent.cs (5)
42-56: Constructor properly uses base Random instance.The constructor correctly initializes all required structures and now uses the inherited
Randominstance from the base class (line 55), preserving seeded behavior for reproducibility.
252-282: SetParameters correctly preserves state keys.The implementation now properly captures state keys before any modifications (line 255) and updates Q-values while preserving the table structure. The past issue where
Clear()was called before capturing keys has been resolved.
284-306: Clone properly deep-copies all state.The implementation now correctly deep-copies both the Q-table (lines 289-292) and the returns history (lines 295-302), ensuring clones are fully isolated from the source agent. The past issue with shared inner dictionaries and missing return history has been resolved.
202-232: Serialization now fully implemented.Both
Serialize()andDeserialize()now have complete implementations using JSON serialization, including proper null validation and safe restoration of state. The past issue where these wereNotImplementedExceptionstubs has been resolved.
58-90: Core action selection and experience storage logic is correct.The epsilon-greedy action selection (lines 58-75) and episode management in
StoreExperience(lines 77-90) are properly implemented. The agent correctly:
- Explores with probability ε and exploits with probability 1-ε
- Accumulates episode data
- Triggers updates and epsilon decay on episode completion
src/ReinforcementLearning/Agents/DoubleDQN/DoubleDQNAgent.cs (4)
58-72: LGTM: Constructor properly initializes networks and replay buffer.The null-handling fix (via
CreateBaseOptions) correctly validates options before base construction. Network initialization, weight synchronization, and registration are all correct.
75-96: LGTM: Null validation now occurs before base constructor call.The static helper pattern cleanly addresses the previous null-handling issue by validating
optionsbefore constructing base options.
295-298: LGTM: GetParameters correctly returns network parameters.The previous undefined variable issue has been fixed.
301-305: LGTM: SetParameters correctly synchronizes target network.The fix ensures both online and target networks remain consistent after external parameter injection.
src/ReinforcementLearning/Agents/IQL/IQLAgent.cs (2)
255-296: Value function update looks correct.The expectile regression implementation properly computes the asymmetric weight and applies the correct gradient
-2 * weight * (q - v)for backpropagation.
317-376: Q-function updates are correct.Both Q-networks properly apply the MSE gradient
-2 * (target - prediction)in their backward passes.src/ReinforcementLearning/Agents/MADDPG/MADDPGAgent.cs (6)
56-102: LGTM! Constructor properly validates configuration and initializes networks.The constructor now correctly:
- Validates options early (line 60) to catch bad configurations
- Uses the configured
ActorLearningRatefor the default optimizer (line 64)- Initializes all agent networks including target networks
104-131: LGTM! Actor network properly configured with per-layer hidden sizes.The network creation now correctly:
- Iterates through
ActorHiddenLayersto create properly sized hidden layers- Uses Xavier/Glorot initialization via
DenseLayerconstructor- Applies Tanh activation for continuous action output
455-480: LGTM! Target actions computed correctly.This method properly:
- Decomposes the joint state into per-agent states
- Computes actions using target actor networks (line 472)
- Concatenates them into a joint action vector
This correctly implements the MADDPG target computation using target actors.
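For context, the centralized target this corresponds to can be written as follows (standard MADDPG formulation, not code from this PR):

y_i = r_i + \gamma \, Q_i^{\text{target}}\bigl(s',\, a_1', \dots, a_N'\bigr), \qquad a_j' = \mu_j^{\text{target}}(o_j')

where each a'_j comes from agent j's target actor, matching the decompose-predict-concatenate steps listed above.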
626-635: Clone correctly copies main network weights.
The implementation now properly transfers trained parameters (lines 631-632) using GetParameters and SetParameters.
Note: The cloned agent inherits the target network synchronization issue from SetParameters (see comment on lines 558-616), but that's not a defect in Clone itself.
658-663: LGTM! ApplyGradients correctly throws NotSupportedException.
The method now properly indicates that external gradient application is not supported for MADDPG, with a clear error message directing users to the agent's internal Train() loop.
535-556: LGTM! Serialization limitation properly documented.
The serialization methods now correctly:
- Throw NotSupportedException with clear error messages
- Include XML documentation explaining the limitation (lines 522-534, 540-552, 669-675, 682-691)
- Direct users to GetParameters()/SetParameters() as the alternative
This is an appropriate way to handle a design limitation.
Also applies to: 676-695
src/ReinforcementLearning/Agents/MonteCarlo/MonteCarloExploringStartsAgent.cs (13)
26-40: Constructor correctly uses the base Random for seeded behavior.
The constructor now properly assigns the base class Random instance to _random, preserving any seed configured via options.Seed for reproducible behavior.
42-73: LGTM! Exploring starts logic is correct.
The method properly implements exploring starts by selecting a random action on the first step of an episode during training, then following the greedy policy based on Q-values for subsequent steps.
75-86: LGTM! Episode storage and reset logic is correct.
The method correctly stores experience tuples and properly resets _isFirstAction when an episode completes, ensuring the exploring start occurs on the first action of the next episode.
88-127: LGTM! First-visit Monte Carlo update is correctly implemented.
The training logic correctly processes episodes in reverse, uses first-visit tracking via the visited HashSet, accumulates returns, and updates Q-values by averaging returns. This matches the expected pattern for Monte Carlo Exploring Starts.
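For readers unfamiliar with the algorithm, a compact first-visit Monte Carlo update over one finished episode looks roughly like this (a plain-double sketch under assumed types, not the agent's generic implementation):
// First-visit Monte Carlo sketch: walk the episode backwards accumulating the return G;
// only the earliest occurrence of each (state, action) pair contributes to its average.
static void FirstVisitUpdate(
    List<(string StateKey, int Action, double Reward)> episode,
    Dictionary<(string, int), List<double>> returns,
    Dictionary<(string, int), double> qTable,
    double gamma)
{
    // Record the index of the first (earliest) visit of each pair.
    var firstVisit = new Dictionary<(string, int), int>();
    for (int t = 0; t < episode.Count; t++)
    {
        var key = (episode[t].StateKey, episode[t].Action);
        if (!firstVisit.ContainsKey(key)) firstVisit[key] = t;
    }

    double g = 0.0;
    for (int t = episode.Count - 1; t >= 0; t--)
    {
        g = episode[t].Reward + gamma * g;   // return from step t onward
        var key = (episode[t].StateKey, episode[t].Action);
        if (firstVisit[key] != t) continue;  // skip later re-visits of the same pair

        if (!returns.TryGetValue(key, out var list)) returns[key] = list = new List<double>();
        list.Add(g);

        double sum = 0.0;
        foreach (var r in list) sum += r;
        qTable[key] = sum / list.Count;      // Q = average of observed first-visit returns
    }
}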
129-179: LGTM! Helper methods are correctly implemented.
The utility methods (EnsureStateExists, GetStateKey, ComputeAverage, ArgMax) follow the same patterns as other Monte Carlo agents and correctly use the NumOps abstraction for type-generic operations.
181-194: LGTM! Metrics and episode reset are correctly implemented.
GetMetrics provides useful diagnostic information, and ResetEpisode properly clears the episode buffer and resets the exploring-start flag.
213-223: LGTM! Metadata and properties are correctly implemented.
The model metadata, parameter count, and feature count properties provide accurate information about the agent's state and configuration.
225-269: LGTM! Serialization now preserves _isFirstAction state.
The serialization correctly includes _isFirstAction in the state, and deserialization safely handles backward compatibility by defaulting to true if the field is missing, preserving the exploring-starts guarantee.
271-294: LGTM! GetParameters now returns Vector correctly.
The method properly flattens the Q-table parameters into a Vector<T> and handles the empty-table case by returning a vector with a single zero element.
312-335: LGTM! Clone now preserves complete agent state.
The method correctly deep-copies _qTable and _returns to avoid shared mutable state, and now also preserves _isFirstAction to maintain mid-episode state consistency.
337-353: LGTM! Gradient methods correctly reflect the tabular nature.
ComputeGradients provides a loss-based gradient for compatibility with the interface, while ApplyGradients correctly does nothing since Monte Carlo methods update Q-values directly from episode returns rather than gradient descent.
355-380: LGTM! Model persistence now works correctly.
SaveModel and LoadModel properly call the now-implemented Serialize and Deserialize methods, with appropriate filepath validation and error handling.
202-211: Adding override keywords would be incorrect: these methods do not override base members.
PredictAsync and TrainAsync are new methods specific to this class, not part of any base class or interface contract. Adding override would cause a compiler error since there are no virtual or abstract methods with these signatures to override in ReinforcementLearningAgentBase<T> or its inherited interfaces.
The base class provides Predict(Vector<T>) (virtual) and Train() (abstract), but neither async variants nor any PredictAsync/TrainAsync methods exist in the inheritance chain.
Likely an incorrect or invalid review comment.
src/ReinforcementLearning/Agents/EligibilityTraces/SARSALambdaAgent.cs (6)
27-51: LGTM: epsilon-greedy selection correctly implemented.
The action selection properly uses the base class Random for reproducibility and correctly applies the ArgMax helper to find the greedy action. Previous concerns about code duplication have been addressed.
96-126: LGTM: helper methods are well implemented.
The helper methods correctly handle state initialization, key generation, and argmax computation. The use of "F4" formatting in GetStateKey provides stable keys while avoiding floating-point precision issues.
183-203: LGTM: Clone now correctly preserves all learned state.
The clone method now deep-copies the Q-table, eligibility traces, epsilon, and episode state (_lastState, _lastAction). This addresses a previous concern about the cloned agent not preserving learned state.
136-166: LGTM: Serialization now properly implemented.
The Serialize and Deserialize methods now use JSON serialization with proper validation. This addresses a previous critical issue where these methods threw NotImplementedException, making Save and Load non-functional.
206-231: LGTM: Save and Load methods properly implemented.
Both SaveModel and LoadModel include proper validation (null/whitespace checks, file existence check) and correctly delegate to the Serialize and Deserialize methods. Good error handling throughout.
131-132: The original review comment is incorrect.
Based on verification, PredictAsync and TrainAsync in SARSALambdaAgent.cs lines 131-132 are not overriding anything from the base class. The base class ReinforcementLearningAgentBase<T> contains Predict(Vector<T> input) (synchronous, not async) and Train() (non-generic, not async), but does not declare PredictAsync or TrainAsync. These methods are new async convenience wrappers added at the derived class level, not overrides of base class members. The override keyword is therefore inappropriate and would be incorrect.
Likely an incorrect or invalid review comment.
src/ReinforcementLearning/Agents/MonteCarlo/FirstVisitMonteCarloAgent.cs
Add optional Random parameter to MathHelper.GetNormalRandom() to enable reproducible Gaussian noise generation when callers pass seeded Random instances. Addresses PR #481 review comment about MathHelper.GetNormalRandom creating new Random on each call.
Actionable comments posted: 0
🧹 Nitpick comments (1)
src/Helpers/MathHelper.cs (1)
468-472: Implementation is correct.
The null-coalescing operator properly handles the optional parameter, and renaming the local variable to rng avoids shadowing the parameter. The Box-Muller transform implementation remains correct, properly using 1.0 - NextDouble() to avoid edge cases with log(0).
Note on thread safety: Random instances are not thread-safe. If callers need to generate normal random numbers concurrently, they should either use thread-local Random instances or consider System.Random alternatives like RandomNumberGenerator for cryptographic scenarios. This is a general concern with Random usage, not specific to this change.
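For reference, a minimal Box-Muller sketch in the spirit of the change under review, with an optional caller-supplied Random for reproducibility (an illustrative standalone helper, not MathHelper's exact signature):
// Box-Muller transform: two uniform samples -> one normally distributed sample.
// Using 1.0 - NextDouble() keeps u1 in (0, 1], so Math.Log(u1) never receives zero.
static double NextGaussian(double mean = 0.0, double stdDev = 1.0, Random? random = null)
{
    var rng = random ?? new Random();            // caller may pass a seeded instance
    double u1 = 1.0 - rng.NextDouble();
    double u2 = rng.NextDouble();
    double z = Math.Sqrt(-2.0 * Math.Log(u1)) * Math.Cos(2.0 * Math.PI * u2);
    return mean + stdDev * z;
}
With a seeded instance, e.g. NextGaussian(0.0, 0.1, new Random(42)), the generated noise sequence is reproducible across runs.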
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
src/Helpers/MathHelper.cs (2 hunks)
🔇 Additional comments (1)
src/Helpers/MathHelper.cs (1)
446-465: LGTM! Excellent addition for reproducibility.
The documentation clearly explains the new optional random parameter and its implications for reproducibility. The signature change is backward-compatible, and the guidance about passing a seeded Random instance for deterministic results is valuable, especially for RL experimentation and testing scenarios.
Add Validate() method to IQLOptions<T> that validates StateSize, ActionSize, BatchSize, and BufferSize are > 0, following the same pattern as MADDPGOptions to ensure fail-fast behavior. Addresses PR #481 review comment about missing options validation.
Actionable comments posted: 0
♻️ Duplicate comments (2)
src/Models/Options/IQLOptions.cs (2)
1-4: Add missing using for MathHelper (compile-time error).
MathHelper.GetNumericOperations<T>() is used in the constructor, but there is no using for the helpers namespace. As noted in the earlier review, project implicit/global usings will not import AiDotNet.Helpers, so this file will not compile as-is.
Apply:
-using AiDotNet.LossFunctions;
+using AiDotNet.LossFunctions;
+using AiDotNet.Helpers;
65-78: Extend Validate() to cover more invariants (expectile, batch/buffer relationship, null checks).
You’ve added essential checks for StateSize, ActionSize, BatchSize, and BufferSize, which is a good improvement. However, several critical invariants are still unchecked, so invalid configurations can slip through and cause runtime failures or nonsensical training behavior (this was raised in the earlier review).
At minimum, it would be safer to:
- Enforce BatchSize <= BufferSize.
- Enforce Expectile ∈ (0, 1).
- Ensure QLossFunction and the hidden-layer lists are non-null and non-empty.
For example:
  public void Validate()
  {
      if (StateSize <= 0) throw new ArgumentException("StateSize must be greater than 0", nameof(StateSize));
      if (ActionSize <= 0) throw new ArgumentException("ActionSize must be greater than 0", nameof(ActionSize));
      if (BatchSize <= 0) throw new ArgumentException("BatchSize must be greater than 0", nameof(BatchSize));
      if (BufferSize <= 0) throw new ArgumentException("BufferSize must be greater than 0", nameof(BufferSize));
+
+     if (BatchSize > BufferSize)
+         throw new ArgumentException("BatchSize must be less than or equal to BufferSize", nameof(BatchSize));
+
+     if (Expectile <= 0.0 || Expectile >= 1.0)
+         throw new ArgumentException("Expectile must be in the (0, 1) range", nameof(Expectile));
+
+     if (QLossFunction is null)
+         throw new InvalidOperationException("QLossFunction cannot be null.");
+
+     if (PolicyHiddenLayers is null || PolicyHiddenLayers.Count == 0)
+         throw new ArgumentException("PolicyHiddenLayers must contain at least one layer size.", nameof(PolicyHiddenLayers));
+
+     if (QHiddenLayers is null || QHiddenLayers.Count == 0)
+         throw new ArgumentException("QHiddenLayers must contain at least one layer size.", nameof(QHiddenLayers));
+
+     if (ValueHiddenLayers is null || ValueHiddenLayers.Count == 0)
+         throw new ArgumentException("ValueHiddenLayers must contain at least one layer size.", nameof(ValueHiddenLayers));
  }
You can go further (e.g., validate discount and tau ranges or learning-rate positivity) if that aligns with how other option classes validate their hyperparameters.
🧹 Nitpick comments (1)
src/Models/Options/IQLOptions.cs (1)
5-30: Docs are clear; consider XML-doc-friendly formatting for lists.
The high-level explanation of IQL (including “for beginners” and advantages) is excellent. If you rely on generated XML docs/IntelliSense rendering, consider replacing the Markdown-style bullets (- **...**) with <list>/<item> or plain text so the formatting is consistent with other XML docs in the codebase.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
src/Models/Options/IQLOptions.cs (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/Models/Options/IQLOptions.cs (2)
src/Helpers/NeuralNetworkHelper.cs (1)
ILossFunction(49-76)src/Helpers/MathHelper.cs (1)
MathHelper(19-997)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Build All Frameworks
🔇 Additional comments (2)
src/Models/Options/IQLOptions.cs (2)
31-53: Property surface and defaults look reasonable for IQL configuration.
The exposed properties and defaults (learning rates, discount, tau, temperature, buffer sizes, hidden layers, default MSE loss) are coherent and align with typical IQL setups; the options class is easy to configure and extend.
54-63: Constructor defaults via MathHelper are appropriate.
Using MathHelper.GetNumericOperations<T>() to set learning rates, discount, tau, and temperature from doubles is consistent with the rest of the numeric-ops infrastructure and keeps defaults centralized and type-safe.
…ledqn
The previous code incorrectly used output-space gradients (e.g., 4 elements for the action space) to index into the parameter vector (thousands of network weights), causing IndexOutOfRangeException. Now correctly:
1. Compute output gradients from the loss function
2. Backpropagate through the network
3. Extract parameter gradients using network.GetGradients()
4. Apply the learning rate to parameter gradients for weight updates
This fixes the gradient indexing bug in both DQNAgent.cs:217-228 and DoubleDQNAgent.cs:200-212
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <[email protected]>
Actionable comments posted: 1
♻️ Duplicate comments (4)
src/ReinforcementLearning/Agents/DQN/DQNAgent.cs (2)
298-320: Deserialize discards critical hyperparameters.
Lines 304-307 read stateSize, actionSize, learningRate, and discountFactor but don't use them. A deserialized agent will ignore its saved learning configuration, causing training inconsistencies or failures when resuming from checkpoints.
Apply this diff to validate and restore the serialized configuration:
- var stateSize = reader.ReadInt32();
- var actionSize = reader.ReadInt32();
- var learningRate = reader.ReadDouble();
- var discountFactor = reader.ReadDouble();
+ var stateSize = reader.ReadInt32();
+ if (stateSize != _dqnOptions.StateSize)
+     throw new InvalidOperationException($"Serialized state size {stateSize} does not match configured {_dqnOptions.StateSize}");
+
+ var actionSize = reader.ReadInt32();
+ if (actionSize != _dqnOptions.ActionSize)
+     throw new InvalidOperationException($"Serialized action size {actionSize} does not match configured {_dqnOptions.ActionSize}");
+
+ var learningRate = reader.ReadDouble();
+ LearningRate = NumOps.FromDouble(learningRate);
+
+ var discountFactor = reader.ReadDouble();
+ DiscountFactor = NumOps.FromDouble(discountFactor);
  _epsilon = reader.ReadDouble();
363-379: ComputeGradients returns output-space gradients instead of parameter gradients.
After calling Backpropagate() at Line 376, the method returns gradient (Line 378), which is the loss derivative w.r.t. outputs (∂Loss/∂output, size ≈ ActionSize). This does not match the method's apparent contract of computing parameter gradients for external optimizers.
The correct pattern is to extract accumulated parameter gradients from the network layers after backpropagation, as done in the Train() method at Lines 222-223.
Apply this diff to return actual parameter gradients:
  var gradientTensor = Tensor<T>.FromVector(gradient);
  _qNetwork.Backpropagate(gradientTensor);
- return gradient;
+ // Extract parameter gradients accumulated during backpropagation
+ return _qNetwork.GetGradients();
src/ReinforcementLearning/Agents/DoubleDQN/DoubleDQNAgent.cs (2)
275-294: Deserialize discards critical hyperparameters.
Lines 280-283 read stateSize, actionSize, learningRate, and discountFactor but discard them without validation or restoration. This breaks checkpoint resumption since the deserialized agent won't retain its training configuration.
Apply this diff:
- reader.ReadInt32(); // stateSize
- reader.ReadInt32(); // actionSize
- reader.ReadDouble(); // learningRate
- reader.ReadDouble(); // discountFactor
+ var stateSize = reader.ReadInt32();
+ if (stateSize != _options.StateSize)
+     throw new InvalidOperationException($"Serialized state size {stateSize} does not match configured {_options.StateSize}");
+
+ var actionSize = reader.ReadInt32();
+ if (actionSize != _options.ActionSize)
+     throw new InvalidOperationException($"Serialized action size {actionSize} does not match configured {_options.ActionSize}");
+
+ var learningRate = reader.ReadDouble();
+ LearningRate = NumOps.FromDouble(learningRate);
+
+ var discountFactor = reader.ReadDouble();
+ DiscountFactor = NumOps.FromDouble(discountFactor);
  _epsilon = reader.ReadDouble();
336-352: ComputeGradients returns output-space gradients instead of parameter gradients.
After Backpropagate() (Line 349), the method returns gradient (Line 351), which is the loss derivative w.r.t. outputs (∂Loss/∂output, size ≈ ActionSize), not the accumulated parameter gradients (∂Loss/∂θ).
Since ApplyGradients throws NotSupportedException, external gradient-based optimization isn't supported. However, the method name and signature imply it should return parameter gradients. Consider either:
- Returning actual parameter gradients via _qNetwork.GetGradients(), or
- Also throwing NotSupportedException with clear documentation that external gradient APIs are not supported.
If parameter gradients should be exposed, apply this diff:
  var gradientTensor = Tensor<T>.FromVector(gradient);
  _qNetwork.Backpropagate(gradientTensor);
- return gradient;
+ return _qNetwork.GetGradients();
🧹 Nitpick comments (1)
src/ReinforcementLearning/Agents/DoubleDQN/DoubleDQNAgent.cs (1)
362-367: ApplyGradients correctly declares an unsupported operation.
Throwing NotSupportedException is a valid design choice for agents that manage gradient updates internally. However, ComputeGradients (Lines 336-352) is still callable and returns output-space gradients, creating an inconsistent API where one method works but its natural pair doesn't.
Consider also throwing NotSupportedException in ComputeGradients for API consistency, or updating ComputeGradients to return parameter gradients and making both methods fully functional. Document the design intent clearly in XML comments.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
src/ReinforcementLearning/Agents/DQN/DQNAgent.cs (1 hunks)
src/ReinforcementLearning/Agents/DoubleDQN/DoubleDQNAgent.cs (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
src/ReinforcementLearning/Agents/DoubleDQN/DoubleDQNAgent.cs (3)
src/ReinforcementLearning/Agents/DQN/DQNAgent.cs (9)
T(165-248)T(421-434)NeuralNetwork(103-131)CopyNetworkWeights(398-402)ReinforcementLearningOptions(80-101)Vector(134-155)Vector(323-326)Vector(363-379)ArgMax(404-419)src/Models/Options/DoubleDQNOptions.cs (2)
DoubleDQNOptions(9-32)DoubleDQNOptions(26-31)src/ReinforcementLearning/Agents/ReinforcementLearningAgentBase.cs (2)
ReinforcementLearningOptions(421-487)Train(170-175)
src/ReinforcementLearning/Agents/DQN/DQNAgent.cs (4)
src/ReinforcementLearning/Agents/DoubleDQN/DoubleDQNAgent.cs (14)
T(151-228)NeuralNetwork(98-121)CopyNetworkWeights(370-373)ReinforcementLearningOptions(75-96)Vector(124-142)Vector(297-300)Vector(336-352)ArgMax(375-390)StoreExperience(145-148)Serialize(251-272)Deserialize(275-294)SetParameters(303-307)IFullModel(310-333)ApplyGradients(362-367)src/Models/Options/DQNOptions.cs (2)
DQNOptions(9-32)DQNOptions(26-31)src/ReinforcementLearning/ReplayBuffers/UniformReplayBuffer.cs (5)
UniformReplayBuffer(18-96)UniformReplayBuffer(35-44)Add(47-59)List(62-82)CanSample(85-88)src/NeuralNetworks/NeuralNetwork.cs (3)
Tensor(179-197)Train(228-262)UpdateParameters(141-154)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Build All Frameworks
🔇 Additional comments (6)
src/ReinforcementLearning/Agents/DQN/DQNAgent.cs (3)
57-77: Constructor implementation looks solid.
The use of the CreateBaseOptions helper properly validates the options parameter before the base constructor call, avoiding the null reference issue flagged in previous reviews. Network initialization and synchronization logic is correct.
329-334: SetParameters correctly synchronizes both networks.
The target network sync at Line 333 ensures both networks remain consistent after deserialization, cloning, or external parameter updates.
216-232: Training loop correctly applies parameter gradients.
The code properly extracts parameter gradients via _qNetwork.GetGradients() at Line 222 after backpropagation, then applies them element-wise to the network parameters. This addresses the dimension mismatch issue flagged in previous reviews.
src/ReinforcementLearning/Agents/DoubleDQN/DoubleDQNAgent.cs (3)
58-72: Constructor properly handles initialization.
The CreateBaseOptions helper ensures null safety before the base constructor call, and network initialization follows the correct pattern with proper synchronization.
303-307: SetParameters correctly synchronizes both networks.
The target network sync at Line 306 ensures consistency after parameter updates, properly addressing the previous review concern.
164-228: Training loop correctly implements the Double DQN algorithm.
The code properly implements the Double DQN innovation: the online network selects the best action (Line 178), while the target network evaluates it (Line 183). Parameter gradient extraction and application (Lines 205-212) correctly use GetGradients() to avoid the dimension mismatch issue.
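For context, the selection/evaluation split described above reduces to the following target computation (a plain-array sketch, not the agent's tensor-based code):
// Double DQN target: the online network picks the argmax action for the next state,
// and the target network supplies the Q-value used in the bootstrap.
static double DoubleDqnTarget(double reward, bool done, double[] onlineQNext, double[] targetQNext, double gamma)
{
    if (done) return reward;

    int bestAction = 0;
    for (int a = 1; a < onlineQNext.Length; a++)      // selection: online network
    {
        if (onlineQNext[a] > onlineQNext[bestAction]) bestAction = a;
    }
    return reward + gamma * targetQNext[bestAction];  // evaluation: target network
}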
…nize after loading
Actionable comments posted: 1
♻️ Duplicate comments (3)
src/ReinforcementLearning/Agents/MonteCarlo/FirstVisitMonteCarloAgent.cs (1)
150-158: State key precision remains inconsistent with other MC agents.
Line 155 still uses ToString("F4") (4 decimal places) while EveryVisitMonteCarloAgent and other Monte Carlo agents in the codebase use ToString("F8") (8 decimal places). Lower precision can cause state aliasing where numerically distinct states are incorrectly treated as identical.
Apply this diff for consistency:
- parts[i] = NumOps.ToDouble(state[i]).ToString("F4");
+ parts[i] = NumOps.ToDouble(state[i]).ToString("F8");
src/ReinforcementLearning/Agents/MADDPG/MADDPGAgent.cs (2)
305-326: Critical: Remove incorrect manual gradient application after backpropagation.
Lines 320-324 apply a uniform scalar gradient to all network parameters, overwriting the per-parameter gradients computed by Backpropagate on line 316. The gradients vector has length 1 (the derivative of the scalar Q-value), so gradients[Math.Min(i, gradients.Length - 1)] evaluates to gradients[0] for every parameter, which is not how neural network training works.
The Backpropagate call already computes correct per-parameter gradients internally. Remove the manual parameter update loop:
  if (_criticNetworks[agentId] is NeuralNetwork<T> criticNetwork)
  {
      criticNetwork.Backpropagate(gradientsTensor);
-
-     // Apply gradient descent updates to critic network parameters
-     var parameters = criticNetwork.GetParameters();
-     for (int i = 0; i < parameters.Length && i < gradients.Length; i++)
-     {
-         var update = NumOps.Multiply(_options.CriticLearningRate, gradients[Math.Min(i, gradients.Length - 1)]);
-         parameters[i] = NumOps.Subtract(parameters[i], update);
-     }
-     criticNetwork.UpdateParameters(parameters);
+     // Backpropagate has updated internal gradients; now apply them via optimizer
+     _optimizer.UpdateParameters(_criticNetworks[agentId].Layers.Cast<ILayer<T>>().ToList());
  }
Note: Verify your optimizer applies the learning rate correctly, or implement a direct weight update using the network's internal gradient state.
734-745: Throw NotSupportedException for supervised gradient computation.
ComputeGradients implements a supervised learning pattern (input vs. target), which doesn't apply to MADDPG's RL training. This method likely exists for IFullModel interface compliance but will produce incorrect gradients if invoked.
For consistency with ApplyGradients (lines 755-760), throw NotSupportedException:
  public override Vector<T> ComputeGradients(
      Vector<T> input,
      Vector<T> target,
      ILossFunction<T>? lossFunction = null)
  {
-     var prediction = Predict(input);
-     var usedLossFunction = lossFunction ?? LossFunction;
-     var loss = usedLossFunction.CalculateLoss(prediction, target);
-
-     var gradient = usedLossFunction.CalculateDerivative(prediction, target);
-     return gradient;
+     throw new NotSupportedException(
+         "ComputeGradients is not applicable to MADDPG. Use Train() for RL updates.");
  }
🧹 Nitpick comments (2)
src/ReinforcementLearning/Agents/MonteCarlo/FirstVisitMonteCarloAgent.cs (1)
160-170: Consider more defensive validation in GetActionIndex.
The method silently returns 0 if no positive element is found, which could mask malformed action vectors. Other agents in the codebase (e.g., EveryVisitMonteCarloAgent) throw ArgumentException for invalid vectors, improving debuggability.
Consider adding validation:
  private int GetActionIndex(Vector<T> action)
  {
+     if (action is null || action.Length == 0)
+     {
+         throw new ArgumentException("Action vector cannot be null or empty", nameof(action));
+     }
+
      for (int i = 0; i < action.Length; i++)
      {
          if (NumOps.GreaterThan(action[i], NumOps.Zero))
          {
              return i;
          }
      }
-     return 0;
+
+     throw new ArgumentException(
+         "Invalid action vector: expected one-hot encoded vector with exactly one positive element.",
+         nameof(action));
  }
src/ReinforcementLearning/Agents/MADDPG/MADDPGAgent.cs (1)
674-712: Target network parameter loading is immediately overwritten.
Lines 674-698 load target network parameters from the serialized vector, but lines 700-712 immediately overwrite them by copying from the main networks. This makes the target network parameter serialization/deserialization redundant.
Option 1 (recommended): Don't serialize target network parameters since they'll be synchronized anyway:
public override Vector<T> GetParameters() { var allParams = new List<T>(); foreach (var network in _actorNetworks) { var netParams = network.GetParameters(); for (int i = 0; i < netParams.Length; i++) { allParams.Add(netParams[i]); } } foreach (var network in _criticNetworks) { var netParams = network.GetParameters(); for (int i = 0; i < netParams.Length; i++) { allParams.Add(netParams[i]); } } - - foreach (var network in _targetActorNetworks) - { - // ... (remove target network serialization) - } - - foreach (var network in _targetCriticNetworks) - { - // ... (remove target critic serialization) - } return new Vector<T>(allParams.ToArray()); }And update
SetParameters to only load main networks, then sync targets at the end (keeping lines 700-712).
Option 2: Keep target parameters as loaded (remove synchronization at lines 700-712) if you want to preserve the target network state exactly as saved.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
src/ReinforcementLearning/Agents/MADDPG/MADDPGAgent.cs (1 hunks)
src/ReinforcementLearning/Agents/MonteCarlo/FirstVisitMonteCarloAgent.cs (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
src/ReinforcementLearning/Agents/MonteCarlo/FirstVisitMonteCarloAgent.cs (6)
src/ReinforcementLearning/Agents/MonteCarlo/MonteCarloExploringStartsAgent.cs (14)
T(88-92)T(148-162)Dictionary(181-188)Vector(42-73)Vector(196-200)Vector(271-294)Vector(337-348)StoreExperience(75-86)UpdateFromEpisode(94-127)EnsureStateExists(129-141)Serialize(225-236)Deserialize(238-269)SetParameters(296-310)IFullModel(312-335)src/ReinforcementLearning/Agents/ExpectedSARSA/ExpectedSARSAAgent.cs (14)
T(112-141)T(143-146)Vector(71-88)Vector(251-266)Vector(306-309)VectorToStateKey(148-156)GetBestAction(191-206)StoreExperience(90-110)GetActionIndex(158-177)EnsureStateExists(179-189)Serialize(221-231)Deserialize(233-249)SetParameters(268-286)IFullModel(288-304)src/ReinforcementLearning/Agents/MonteCarlo/OnPolicyMonteCarloAgent.cs (14)
T(90-94)T(150-164)Dictionary(183-191)Vector(42-73)Vector(198-201)Vector(258-281)Vector(323-334)StoreExperience(75-88)UpdateFromEpisode(96-129)EnsureStateExists(131-143)Serialize(226-237)Deserialize(239-256)SetParameters(283-297)IFullModel(299-321)src/ReinforcementLearning/Agents/ReinforcementLearningAgentBase.cs (3)
ReinforcementLearningAgentBase(34-415)ReinforcementLearningAgentBase(90-111)Train(170-175)src/Models/Options/MonteCarloOptions.cs (1)
MonteCarloOptions(9-13)src/ReinforcementLearning/Agents/MonteCarlo/EveryVisitMonteCarloAgent.cs (5)
VectorToStateKey(106-114)GetBestAction(153-167)GetActionIndex(122-139)UpdateFromEpisode(75-97)EnsureStateExists(141-151)
src/ReinforcementLearning/Agents/MADDPG/MADDPGAgent.cs (6)
src/ReinforcementLearning/Agents/DeepReinforcementLearningAgentBase.cs (2)
DeepReinforcementLearningAgentBase(33-105)DeepReinforcementLearningAgentBase(65-69)src/Models/Options/MADDPGOptions.cs (3)
MADDPGOptions(33-73)MADDPGOptions(53-59)Validate(64-72)src/ReinforcementLearning/Agents/TD3/TD3Agent.cs (4)
CopyNetworkWeights(389-393)NeuralNetwork(104-128)NeuralNetwork(130-155)UpdateActor(315-361)src/NeuralNetworks/FeedForwardNeuralNetwork.cs (5)
Tensor(131-146)Tensor(166-178)Tensor(199-208)Train(258-292)UpdateParameters(226-239)src/NeuralNetworks/NeuralNetwork.cs (3)
Tensor(179-197)Train(228-262)UpdateParameters(141-154)src/Helpers/MathHelper.cs (1)
MathHelper(19-997)
🪛 GitHub Actions: Build
src/ReinforcementLearning/Agents/MADDPG/MADDPGAgent.cs
[error] 382-382: CS1061: 'NeuralNetwork' does not contain a definition for 'GetInputGradients' and no accessible extension method 'GetInputGradients' accepting a first argument of type 'NeuralNetwork' could be found (are you missing a using directive or an assembly reference?)
🪛 GitHub Actions: Quality Gates (.NET)
src/ReinforcementLearning/Agents/MADDPG/MADDPGAgent.cs
[error] 382-382: CS1061: 'NeuralNetwork' does not contain a definition for 'GetInputGradients' and no accessible extension method 'GetInputGradients' accepting a first argument of type 'NeuralNetwork' could be found (are you missing a using directive or an assembly reference?)
🪛 GitHub Check: Build All Frameworks
src/ReinforcementLearning/Agents/MADDPG/MADDPGAgent.cs
[failure] 382-382:
'NeuralNetwork' does not contain a definition for 'GetInputGradients' and no accessible extension method 'GetInputGradients' accepting a first argument of type 'NeuralNetwork' could be found (are you missing a using directive or an assembly reference?)
[failure] 382-382:
'NeuralNetwork' does not contain a definition for 'GetInputGradients' and no accessible extension method 'GetInputGradients' accepting a first argument of type 'NeuralNetwork' could be found (are you missing a using directive or an assembly reference?)
[failure] 382-382:
'NeuralNetwork' does not contain a definition for 'GetInputGradients' and no accessible extension method 'GetInputGradients' accepting a first argument of type 'NeuralNetwork' could be found (are you missing a using directive or an assembly reference?)
[failure] 382-382:
'NeuralNetwork' does not contain a definition for 'GetInputGradients' and no accessible extension method 'GetInputGradients' accepting a first argument of type 'NeuralNetwork' could be found (are you missing a using directive or an assembly reference?)
🪛 GitHub Check: Publish Size Analysis
src/ReinforcementLearning/Agents/MADDPG/MADDPGAgent.cs
[failure] 382-382:
'NeuralNetwork' does not contain a definition for 'GetInputGradients' and no accessible extension method 'GetInputGradients' accepting a first argument of type 'NeuralNetwork' could be found (are you missing a using directive or an assembly reference?)
🔇 Additional comments (4)
src/ReinforcementLearning/Agents/MonteCarlo/FirstVisitMonteCarloAgent.cs (3)
42-56: Constructor properly addressed past seeding concern.
The constructor now correctly assigns _random = Random (line 55), using the inherited base Random instance initialized with the seed from options. This preserves deterministic behavior for testing and reproducibility.
268-298: SetParameters correctly addressed the past critical issue.
The implementation now properly captures state keys before any modifications (line 271) and updates Q-values in place without clearing the table. This resolves the previous blocker where _qTable.Clear() was called before capturing keys, which erased all value estimates.
218-248: Serialization, cloning, and persistence methods properly implemented.
The implementation correctly addresses all past concerns:
Clone (lines 300-322): Now deep-copies both the Q-table and the returns history (lines 311-317), ensuring clones are fully isolated with preserved return data for accurate first-visit averaging.
Serialize/Deserialize (lines 218-248): Replaced the previous NotImplementedException stubs with proper JSON-based serialization, including Q-table, returns history, epsilon, and options.
SaveModel/LoadModel (lines 331-356): Now functional since they rely on the implemented Serialize/Deserialize methods.
Also applies to: 300-322, 331-356
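As a rough picture of what such a JSON round-trip involves, here is a minimal sketch assuming System.Text.Json and simplified field names (the agent's actual serializer and schema may differ):
using System.Text.Json;

// Simplified stand-in for the Serialize/Deserialize pattern described above:
// the tabular state (Q-table plus exploration rate) is round-tripped as JSON.
class TabularAgentState
{
    public Dictionary<string, double[]> QTable { get; set; } = new();
    public double Epsilon { get; set; }
}

static byte[] SerializeState(TabularAgentState state) =>
    JsonSerializer.SerializeToUtf8Bytes(state);

static TabularAgentState DeserializeState(byte[] data) =>
    JsonSerializer.Deserialize<TabularAgentState>(data)
    ?? throw new InvalidOperationException("Deserialized state was null.");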
src/ReinforcementLearning/Agents/MADDPG/MADDPGAgent.cs (1)
225-232: Reward averaging limits use to cooperative scenarios; confirm if per-agent tracking is needed.
The implementation averages rewards across all agents (lines 225-232), preventing per-agent reward signals. While documented as a known limitation (lines 207-211), this design choice restricts the agent to cooperative scenarios only.
Verify whether this cooperative-only constraint is acceptable for your use case. If competitive or mixed-motive environments are required, implement per-agent reward storage by modifying StoreMultiAgentExperience() to preserve individual agent rewards instead of averaging them.
…loagent to avoid state aliasing
…arios PROBLEM: MADDPG was averaging rewards across all agents before storing in replay buffer. This only works for fully cooperative scenarios where all agents share identical rewards. In competitive or mixed-motive environments, agents have distinct (often opposing) rewards that should not be averaged. SOLUTION: 1. Add SampleWithIndices() method to UniformReplayBuffer to return both experiences and their buffer indices 2. Store per-agent rewards in _perAgentRewards dictionary keyed by buffer index 3. Account for circular buffer behavior when storing rewards 4. Modify UpdateCritic to use agent-specific rewards instead of averaged rewards 5. Maintain backward compatibility by falling back to averaged rewards when per-agent rewards aren't available IMPACT: - Enables MADDPG to work in competitive and mixed-motive scenarios - Maintains backward compatibility with cooperative scenarios - Each agent's critic now receives the correct reward signal for that agent - No breaking changes to existing code FILES MODIFIED: - src/ReinforcementLearning/ReplayBuffers/UniformReplayBuffer.cs:296 - src/ReinforcementLearning/Agents/MADDPG/MADDPGAgent.cs:296 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <[email protected]>
…gradients PROBLEM: The UpdateCritic method was incorrectly applying gradients to network parameters. Lines 362-366 applied a scalar output-space gradient uniformly to all parameters: - gradients vector had length 1 (derivative of scalar Q-value w.r.t. loss) - gradients[0] was applied to ALL network parameters - This is fundamentally wrong - each parameter needs its own gradient The Backpropagate() call on line 358 already computed the correct per-parameter gradients internally, but then the manual loop overwrote those updates with a uniform scalar applied to every parameter. SOLUTION: Follow the correct pattern from DQNAgent.cs (lines 221-231): 1. Call Backpropagate() to compute gradients through the network 2. Call GetGradients() to retrieve the per-parameter gradients (not output-space) 3. Apply learning rate to each parameter using its corresponding gradient 4. Update network parameters with the corrected values IMPACT: - Critic network now receives correct per-parameter gradients from backpropagation - Each weight/bias is updated based on its actual contribution to the loss - Fixes critical bug that would prevent MADDPG from learning effectively - Follows the same pattern as other working agents (DQN, DoubleDQN) Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <[email protected]>
…orrect action gradient indexing PROBLEM: The UpdateActor method had two critical issues: 1. Called criticNetwork.GetInputGradients() which doesn't exist 2. Incorrect indexing when extracting action gradients from joint [state, action] input Lines 427: Attempted to call non-existent GetInputGradients() method Lines 434-441: Used wrong indexing formula combining actionStartIdx + jointActionIdx SOLUTION: 1. Capture the return value from Backpropagate(): - Backpropagate() returns Tensor<T> containing gradients w.r.t. network input - Convert to Vector<T> using .ToVector() - This gives us dQ/d(state,action) for the entire joint input 2. Fix action gradient indexing: - Critic input structure: [jointState, jointAction] concatenated - jointStateSize = total length of all agents' states - Action gradients start at index jointStateSize - This agent's actions: jointStateSize + (agentId * actionSize + i) - Removed incorrect formula using actionStartIdx + jointActionIdx CORRECT INDEXING: For critic input [s1, s2, ..., sN, a1, a2, ..., aN]: - States occupy indices [0, jointStateSize) - Actions occupy indices [jointStateSize, jointStateSize + jointActionSize) - Agent i's actions: [jointStateSize + i*actionSize, jointStateSize + (i+1)*actionSize) IMPACT: - Actor network can now properly receive gradients from critic - Deterministic policy gradient computation is now mathematically correct - Enables MADDPG to learn actor policies that maximize Q-values Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <[email protected]>
…e variable PROBLEM: Build error CS0136 - variable 'i' was declared in both outer loop (line 300) iterating over batch experiences and inner loop (line 364) iterating over network parameters, causing a variable shadowing conflict that prevented compilation. SOLUTION: Renamed inner loop variable from 'i' to 'paramIdx' to clearly indicate it's iterating over parameter indices and avoid scope conflict with outer 'i' variable. LOCATION: src/ReinforcementLearning/Agents/MADDPG/MADDPGAgent.cs:364 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <[email protected]>
Actionable comments posted: 4
♻️ Duplicate comments (2)
src/ReinforcementLearning/Agents/DQN/DQNAgent.cs (1)
363-379: ComputeGradients returns output-space gradients instead of parameter-space gradients.
After calling _qNetwork.Backpropagate(gradientTensor) at Line 376, the method returns gradient (Line 378), which is the output-space gradient computed by loss.CalculateDerivative(output, target) at Line 373. This breaks external optimizer hooks that expect parameter-space gradients (derivatives with respect to all network weights and biases, not network outputs).
The correct approach is to retrieve the parameter-space gradients accumulated during backpropagation and return those.
Apply this diff to return parameter-space gradients:
  var gradientTensor = Tensor<T>.FromVector(gradient);
  _qNetwork.Backpropagate(gradientTensor);
- return gradient;
+ // Return parameter-space gradients accumulated during backpropagation
+ return _qNetwork.GetGradients();
Note: This assumes _qNetwork.GetGradients() exists and returns a flattened vector of all parameter gradients. If this method doesn't exist, it must be implemented or an alternative approach used (see the verification comment for the Train method).
src/ReinforcementLearning/Agents/MADDPG/MADDPGAgent.cs (1)
779-790: ComputeGradients still implements a supervised pattern, while ApplyGradients throws
ComputeGradients returns lossFunction.CalculateDerivative(prediction, target), i.e., a gradient w.r.t. model outputs, not parameters, and ignores the RL-specific training flow. At the same time, ApplyGradients throws NotSupportedException, so external optimizers cannot safely use this pair for MADDPG.
Either:
- Make
ComputeGradientsmirrorApplyGradientsand throwNotSupportedExceptionwith a message directing users toTrain(), or- Clearly document that
ComputeGradientsis only for optional supervised fine-tuning of the policy/critic and is not used in standard MADDPG training.
🧹 Nitpick comments (5)
src/ReinforcementLearning/Agents/MonteCarlo/FirstVisitMonteCarloAgent.cs (2)
160-170: Consider defensive validation for invalid action vectors.The method returns 0 when no non-zero element is found. While this is reasonable given the one-hot encoding contract, an invalid action vector (all zeros) would be silently handled, potentially masking upstream bugs.
Optionally add validation:
private int GetActionIndex(Vector<T> action) { for (int i = 0; i < action.Length; i++) { if (NumOps.GreaterThan(action[i], NumOps.Zero)) { return i; } } + throw new ArgumentException("Action vector must have at least one non-zero element", nameof(action)); - return 0; }
145-148: Consider documenting stub methods for tabular learning.The stub implementations of
Train(),ComputeGradients(), andApplyGradients()are acceptable for tabular Monte Carlo (which doesn't use gradient-based updates), but might confuse users expecting standard gradient-based RL behavior.Optionally add XML doc comments explaining the tabular nature:
+ /// <summary> + /// Tabular Monte Carlo does not require explicit training steps. + /// Updates occur automatically in StoreExperience when episodes complete. + /// </summary> public override T Train() { return NumOps.Zero; }Also applies to: 324-329
src/ReinforcementLearning/Agents/MADDPG/MADDPGAgent.cs (3)
60-76: Optimizer is configured but unused in training; Clone reuses the same instanceThe constructor wires
_optimizerwith a learning rate derived from_options.ActorLearningRate, butUpdateCritic/UpdateActordo all updates manually and never invoke_optimizer.Clone()then passes the same_optimizerinstance into the new agent, so both agents share an optimizer object that, if it ever starts carrying state or references the original model, may behave unexpectedly.Consider either:
- Removing the optimizer parameter/field from
MADDPGAgentuntil it is actually used, or- Refactoring the critic/actor updates to go through
_optimizerand ensuringClone()builds a fresh optimizer for the cloned agent.Also applies to: 768-775
169-172: UnusedInitializeReplayBufferhelper
InitializeReplayBuffer()duplicates the constructor logic for_replayBufferbut is never called.Either remove this method or call it from the constructor to avoid dead code and keep initialization in one place.
636-758: Parameter management and cloning now correctly cover all networks
GetParameters/SetParametersnow include actors, critics, and both sets of target networks, andSetParametersre-synchronizes targets from the main networks at the end.Clone()uses these methods to copy trained weights, so clones start with identical behavior.One minor note: since
SetParametersalready loads explicit target parameters, the final synchronization step makes target entries in the serialized vector effectively redundant. If you don’t intend to support independently-parameterized targets, you could simplify by omitting target segments from the vector and relying solely on the post-copy sync.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
src/ReinforcementLearning/Agents/DQN/DQNAgent.cs (1 hunks)
src/ReinforcementLearning/Agents/MADDPG/MADDPGAgent.cs (1 hunks)
src/ReinforcementLearning/Agents/MonteCarlo/FirstVisitMonteCarloAgent.cs (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
src/ReinforcementLearning/Agents/MonteCarlo/FirstVisitMonteCarloAgent.cs (3)
src/ReinforcementLearning/Agents/ReinforcementLearningAgentBase.cs (3)
ReinforcementLearningAgentBase(34-415)ReinforcementLearningAgentBase(90-111)Train(170-175)src/Models/Options/MonteCarloOptions.cs (1)
MonteCarloOptions(9-13)src/ReinforcementLearning/Agents/MonteCarlo/EveryVisitMonteCarloAgent.cs (3)
VectorToStateKey(106-114)UpdateFromEpisode(75-97)EnsureStateExists(141-151)
src/ReinforcementLearning/Agents/MADDPG/MADDPGAgent.cs (5)
src/ReinforcementLearning/Agents/DeepReinforcementLearningAgentBase.cs (2)
DeepReinforcementLearningAgentBase(33-105)DeepReinforcementLearningAgentBase(65-69)src/Models/Options/MADDPGOptions.cs (3)
MADDPGOptions(33-73)MADDPGOptions(53-59)Validate(64-72)src/ReinforcementLearning/ReplayBuffers/UniformReplayBuffer.cs (5)
List(62-82)List(93-115)UniformReplayBuffer(18-129)UniformReplayBuffer(35-44)Add(47-59)src/NeuralNetworks/NeuralNetwork.cs (3)
Tensor(179-197)Train(228-262)UpdateParameters(141-154)src/Helpers/MathHelper.cs (1)
MathHelper(19-997)
src/ReinforcementLearning/Agents/DQN/DQNAgent.cs (5)
src/Models/Options/DQNOptions.cs (2)
DQNOptions(9-32)DQNOptions(26-31)src/ReinforcementLearning/ReplayBuffers/UniformReplayBuffer.cs (6)
UniformReplayBuffer(18-129)UniformReplayBuffer(35-44)Add(47-59)List(62-82)List(93-115)CanSample(118-121)src/ReinforcementLearning/Agents/ReinforcementLearningAgentBase.cs (2)
ReinforcementLearningOptions(421-487)Train(170-175)src/NeuralNetworks/FeedForwardNeuralNetwork.cs (5)
Tensor(131-146)Tensor(166-178)Tensor(199-208)Train(258-292)UpdateParameters(226-239)src/NeuralNetworks/NeuralNetwork.cs (3)
Tensor(179-197)Train(228-262)UpdateParameters(141-154)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Build All Frameworks
🔇 Additional comments (13)
src/ReinforcementLearning/Agents/MonteCarlo/FirstVisitMonteCarloAgent.cs (9)
42-56: Constructor properly fixed.
The Random initialization now correctly uses the base class's Random field, preserving seeded/deterministic behavior as intended.
58-75: LGTM!
Epsilon-greedy action selection is correctly implemented with proper exploration during training and greedy exploitation otherwise.
77-127: First-visit Monte Carlo implementation is correct.
The episode storage and update logic properly implements the first-visit constraint using a visited set, and Q-values are correctly computed as the average of observed returns.
129-143: ComputeAverage correctly implemented.
The helper method is now present and correctly computes the average using NumOps operations.
150-158: State key precision correctly updated.
Now uses F8 precision as suggested, matching other Monte Carlo agents and avoiding state aliasing issues.
172-199: LGTM!
Both helper methods are correctly implemented: EnsureStateExists properly initializes all action values, and GetBestAction correctly identifies the optimal action.
268-298: SetParameters correctly fixed.
The method now captures state keys before any modifications and preserves the Q-table structure while updating values, addressing the previous critical issue where the table was being wiped.
300-322: Clone properly deep-copies all data structures.
Both _qTable and _returns are now correctly deep-copied at all levels, ensuring clones are fully isolated from the source agent.
331-356: LGTM!
Model persistence methods are correctly implemented with proper validation and error handling.
src/ReinforcementLearning/Agents/DQN/DQNAgent.cs (4)
57-101: Constructor null handling looks good.
The null check has been properly moved to the CreateBaseOptions helper method, which is called in the base constructor initializer. This ensures that an ArgumentNullException is thrown before any property access on a null options argument.
329-334: Target network synchronization looks good.
SetParameters now correctly synchronizes the target network after updating the Q-network parameters (Lines 332-333). This ensures that both networks remain consistent after deserialization, cloning, or external parameter updates.
382-405: ApplyGradients validation and implementation looks correct.
The method now includes proper validation (Lines 387-394) to ensure the gradient vector length matches the parameter vector length, with a clear error message explaining the requirement for parameter-space gradients. The gradient application logic is correct, and calling SetParameters ensures both the Q-network and target network are updated.
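The validate-then-apply pattern being approved here boils down to something like the following sketch (plain doubles and hypothetical names; the agent itself works on Vector<T> via NumOps):
// Reject gradients whose length does not match the parameter vector, then take
// one SGD step: p <- p - lr * g. Illustrative only.
static double[] ApplySgdStep(double[] parameters, double[] gradients, double learningRate)
{
    if (gradients.Length != parameters.Length)
    {
        throw new ArgumentException(
            $"Expected {parameters.Length} parameter-space gradients but got {gradients.Length}.",
            nameof(gradients));
    }

    var updated = new double[parameters.Length];
    for (int i = 0; i < parameters.Length; i++)
    {
        updated[i] = parameters[i] - learningRate * gradients[i];
    }
    return updated;
}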
221-231: The code is correct; no issues found.
The GetGradients() method exists in NeuralNetworkBase.cs (line 2133) and returns Vector<T> containing parameter-space gradients collected from each layer via layer.GetParameterGradients(). All method signatures are compatible: GetParameters() and GetGradients() both return Vector<T>, which supports indexing and the .Length property. The gradient update loop correctly implements standard SGD parameter updates.
private INeuralNetwork<T> CreateActorNetwork()
{
    // Create layers
    var layers = new List<ILayer<T>>();

    // Input layer
    layers.Add(new DenseLayer<T>(_options.StateSize, _options.ActorHiddenLayers[0], (IActivationFunction<T>)new ReLUActivation<T>()));

    // Hidden layers
    for (int i = 1; i < _options.ActorHiddenLayers.Count; i++)
    {
        layers.Add(new DenseLayer<T>(_options.ActorHiddenLayers[i - 1], _options.ActorHiddenLayers[i], (IActivationFunction<T>)new ReLUActivation<T>()));
    }

    // Output layer with Tanh for continuous actions
    // Issue #1 fix: DenseLayer constructor automatically applies Xavier/Glorot weight initialization
    layers.Add(new DenseLayer<T>(_options.ActorHiddenLayers.Last(), _options.ActionSize, (IActivationFunction<T>)new TanhActivation<T>()));
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hidden-layer lists are assumed non-empty; invalid config will throw at construction
Both CreateActorNetwork and CreateCriticNetwork index ActorHiddenLayers[0]/CriticHiddenLayers[0] and iterate starting from 1. If a caller provides an empty list for either, the agent will throw IndexOutOfRangeException during construction, and Validate() doesn’t guard against this.
Consider:
- Extending MADDPGOptions.Validate() to require ActorHiddenLayers.Count > 0 and CriticHiddenLayers.Count > 0 (see the sketch after this comment), or
- Providing sensible defaults when the lists are empty (e.g., inject a default [128,128]).
Also applies to: 146-153
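If the validation route is taken, a guard along these lines would be enough (a sketch of the suggested check; the real MADDPGOptions.Validate() may organize it differently):
// Hypothetical guard mirroring the suggestion above: fail fast before the
// network builders ever index into an empty hidden-layer list.
static void ValidateHiddenLayers(List<int> actorHiddenLayers, List<int> criticHiddenLayers)
{
    if (actorHiddenLayers == null || actorHiddenLayers.Count == 0)
        throw new ArgumentException("ActorHiddenLayers must contain at least one layer size.");

    if (criticHiddenLayers == null || criticHiddenLayers.Count == 0)
        throw new ArgumentException("CriticHiddenLayers must contain at least one layer size.");
}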
/// <summary>
/// Store multi-agent experience with per-agent reward tracking.
/// </summary>
/// <remarks>
/// Stores individual rewards for each agent to support both cooperative and
/// competitive/mixed-motive scenarios. For backward compatibility, also stores
/// an averaged reward in the experience.
///
/// The per-agent rewards are stored keyed by the buffer index where the experience
/// will be placed. This accounts for the circular buffer behavior when capacity is reached.
/// </remarks>
public void StoreMultiAgentExperience(
    List<Vector<T>> states,
    List<Vector<T>> actions,
    List<T> rewards,
    List<Vector<T>> nextStates,
    bool done)
{
    // Concatenate all agents' observations for centralized storage
    var jointState = ConcatenateVectors(states);
    var jointAction = ConcatenateVectors(actions);
    var jointNextState = ConcatenateVectors(nextStates);

    // Compute the buffer index where this experience will be stored
    // This accounts for circular buffer behavior: if buffer is not full, index = Count
    // If buffer is full, the circular position is used (which we approximate here)
    int bufferIndex;
    if (_replayBuffer.Count < _replayBuffer.Capacity)
    {
        // Buffer not full yet, experience goes at the end
        bufferIndex = _replayBuffer.Count;
    }
    else
    {
        // Buffer is full, circular overwrite - use modulo to find position
        // Note: We approximate the position since we don't have access to internal _position field
        // This works because experiences are added sequentially
        bufferIndex = _stepCount % _replayBuffer.Capacity;
    }

    // Store per-agent rewards at the buffer index for competitive/mixed-motive scenarios
    _perAgentRewards[bufferIndex] = new List<T>(rewards);

    // Also compute average reward for cooperative scenarios (backward compatibility)
    T avgReward = NumOps.Zero;
    foreach (var reward in rewards)
    {
        avgReward = NumOps.Add(avgReward, reward);
    }
    avgReward = NumOps.Divide(avgReward, NumOps.FromDouble(rewards.Count));

    _replayBuffer.Add(new ReplayBuffers.Experience<T>(jointState, jointAction, avgReward, jointNextState, done));
    _stepCount++;
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Per-agent reward bookkeeping breaks if StoreExperience is used alongside StoreMultiAgentExperience
StoreMultiAgentExperience computes a buffer index from _replayBuffer.Count/_stepCount and writes _perAgentRewards[bufferIndex] = ... before calling _replayBuffer.Add(...). This is consistent with the circular buffer behavior.
However, StoreExperience also calls _replayBuffer.Add(...) and increments _stepCount, but does not update _perAgentRewards. If a single-agent StoreExperience call overwrites an entry that previously came from StoreMultiAgentExperience, _perAgentRewards will still hold the old per-agent rewards for that buffer index. Later, UpdateCritic will see _perAgentRewards.ContainsKey(bufferIndex) and use stale per-agent rewards for an experience that no longer matches, silently corrupting training.
To make the bookkeeping robust, you can, for example, clear per-agent entries for single-agent inserts:
public override void StoreExperience(Vector<T> state, Vector<T> action, T reward, Vector<T> nextState, bool done)
{
- _replayBuffer.Add(new ReplayBuffers.Experience<T>(state, action, reward, nextState, done));
- _stepCount++;
+ // Compute the buffer index this experience will occupy (mirror multi-agent path)
+ int bufferIndex;
+ if (_replayBuffer.Count < _replayBuffer.Capacity)
+ {
+ bufferIndex = _replayBuffer.Count;
+ }
+ else
+ {
+ bufferIndex = _stepCount % _replayBuffer.Capacity;
+ }
+
+ // Remove any stale per-agent rewards for this slot
+ _perAgentRewards.Remove(bufferIndex);
+
+ _replayBuffer.Add(new ReplayBuffers.Experience<T>(state, action, reward, nextState, done));
+ _stepCount++;
}Alternatively, move per-agent rewards into the experience structure itself to avoid external index bookkeeping.
Also applies to: 263-267
🤖 Prompt for AI Agents
In src/ReinforcementLearning/Agents/MADDPG/MADDPGAgent.cs around lines 208-261
(and also 263-267), StoreMultiAgentExperience writes per-agent rewards into the
external _perAgentRewards dictionary using a computed bufferIndex but
StoreExperience also overwrites the replay buffer without touching
_perAgentRewards, leaving stale per-agent data that can corrupt training; fix by
either (preferred) moving per-agent reward storage into the
ReplayBuffers.Experience<T> object (update the Experience type, its
constructors, and all call sites including StoreMultiAgentExperience and
StoreExperience, and consume per-agent rewards from the experience in
UpdateCritic), or (simpler) ensure StoreExperience clears or updates
_perAgentRewards for the buffer index it will overwrite (compute the same
bufferIndex logic used in StoreMultiAgentExperience before calling
_replayBuffer.Add and either remove the key or set it to null/empty), and add
unit tests to cover mixed calls of StoreExperience and StoreMultiAgentExperience
to prevent regressions.
private T UpdateActor(int agentId, List<AiDotNet.ReinforcementLearning.ReplayBuffers.Experience<T>> batch)
{
    T totalLoss = NumOps.Zero;

    foreach (var experience in batch)
    {
        // Decompose joint state to get this agent's state
        int stateOffset = agentId * _options.StateSize;
        var agentState = new Vector<T>(_options.StateSize);
        for (int i = 0; i < _options.StateSize; i++)
        {
            agentState[i] = experience.State[stateOffset + i];
        }

        // Compute action from actor
        var agentStateTensor = Tensor<T>.FromVector(agentState);
        var actionTensor = _actorNetworks[agentId].Predict(agentStateTensor);
        var action = actionTensor.ToVector();

        // Reconstruct joint action with this agent's new action
        var jointAction = experience.Action.Clone();
        for (int i = 0; i < _options.ActionSize; i++)
        {
            jointAction[agentId * _options.ActionSize + i] = action[i];
        }

        // Compute Q-value from critic (for deterministic policy gradient)
        var jointStateAction = ConcatenateStateAction(experience.State, jointAction);
        var jointStateActionTensor = Tensor<T>.FromVector(jointStateAction);
        var qValueTensor = _criticNetworks[agentId].Predict(jointStateActionTensor);
        var qValue = qValueTensor.ToVector()[0];

        // Actor loss: maximize Q-value (negated for minimization)
        totalLoss = NumOps.Add(totalLoss, NumOps.Negate(qValue));

        // Deterministic Policy Gradient: backprop through critic to get dQ/dAction
        // Create upstream gradient for critic output (dLoss/dQ = -1 for maximization)
        var criticOutputGradient = new Vector<T>(1);
        criticOutputGradient[0] = NumOps.FromDouble(-1.0); // Negative because we want to maximize Q
        var criticOutputGradientTensor = Tensor<T>.FromVector(criticOutputGradient);

        // Backpropagate through critic to compute gradients w.r.t. its input
        // Note: This computes dQ/d(state,action) internally in the network layers
        if (_criticNetworks[agentId] is NeuralNetwork<T> criticNetwork)
        {
            // Backpropagate returns gradients w.r.t. network input
            var inputGradientsTensor = criticNetwork.Backpropagate(criticOutputGradientTensor);
            var inputGradients = inputGradientsTensor.ToVector();

            // The input to critic is [state, action] concatenated
            // Extract dQ/dAction for this specific agent
            // Action gradients start after all states: jointStateSize
            // This agent's actions are at: jointStateSize + (agentId * _options.ActionSize)
            int jointStateSize = experience.State.Length;
            int jointActionSize = _options.ActionSize * _options.NumAgents;
            var actionGradient = new Vector<T>(_options.ActionSize);

            for (int i = 0; i < _options.ActionSize; i++)
            {
                // Extract gradients for this agent's action from joint action space
                int actionGradientIdx = jointStateSize + (agentId * _options.ActionSize + i);
                if (actionGradientIdx < inputGradients.Length)
                {
                    actionGradient[i] = inputGradients[actionGradientIdx];
                }
                else
                {
                    // Fallback: use simple gradient estimate
                    actionGradient[i] = NumOps.Divide(criticOutputGradient[0], NumOps.FromDouble(_options.ActionSize));
                }
            }

            // Backpropagate action gradient through actor to get parameter gradients
            var actionGradientTensor = Tensor<T>.FromVector(actionGradient);
            if (_actorNetworks[agentId] is NeuralNetwork<T> actorNetwork)
            {
                actorNetwork.Backpropagate(actionGradientTensor);

                // Extract parameter gradients from actor network
                var parameterGradients = actorNetwork.GetGradients();
                var actorParams = actorNetwork.GetParameters();

                // Gradient ascent: θ ← θ + α * ∇_θ J (maximize Q)
                for (int i = 0; i < actorParams.Length && i < parameterGradients.Length; i++)
                {
                    var update = NumOps.Multiply(_options.ActorLearningRate, parameterGradients[i]);
                    actorParams[i] = NumOps.Add(actorParams[i], update); // Add for ascent
                }
                actorNetwork.UpdateParameters(actorParams);
            }
        }
Actor gradient sign is reversed; current update tends to minimize Q instead of maximizing it
In UpdateActor, you set:
criticOutputGradient[0] = NumOps.FromDouble(-1.0); // dLoss/dQ = -1
...
actorNetwork.Backpropagate(actionGradientTensor);
var parameterGradients = actorNetwork.GetGradients();
...
actorParams[i] = NumOps.Add(actorParams[i], update); // update = α * parameterGradients[i]
Given that Backpropagate + GetGradients() return dLoss/dθ, this update performs gradient ascent on the loss (θ ← θ + α · dLoss/dθ), which, with Loss = -Q, actually decreases Q.
To implement the deterministic policy gradient correctly, treat Q itself as the objective and propagate dQ/dAction:
- // Actor loss: maximize Q-value (negated for minimization)
- totalLoss = NumOps.Add(totalLoss, NumOps.Negate(qValue));
-
- // Deterministic Policy Gradient: backprop through critic to get dQ/dAction
- // Create upstream gradient for critic output (dLoss/dQ = -1 for maximization)
- var criticOutputGradient = new Vector<T>(1);
- criticOutputGradient[0] = NumOps.FromDouble(-1.0); // Negative because we want to maximize Q
+ // Track -Q for logging, but treat Q as the optimization objective
+ totalLoss = NumOps.Add(totalLoss, NumOps.Negate(qValue));
+
+ // Deterministic Policy Gradient: backprop through critic to get dQ/dAction
+ // Upstream gradient for critic output is +1 so Backpropagate returns dQ/d(input)
+ var criticOutputGradient = new Vector<T>(1);
+ criticOutputGradient[0] = NumOps.One;
@@
- // Gradient ascent: θ ← θ + α * ∇_θ J (maximize Q)
+ // Gradient ascent on Q: θ ← θ + α * ∇_θ Q
for (int i = 0; i < actorParams.Length && i < parameterGradients.Length; i++)
{
var update = NumOps.Multiply(_options.ActorLearningRate, parameterGradients[i]);
actorParams[i] = NumOps.Add(actorParams[i], update); // Add for ascent
}
This way parameterGradients represents ∂Q/∂θ, and the additive update correctly increases Q.
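As a sanity check on the sign convention, here is a toy, self-contained 1-D illustration (plain C#, not MADDPG or library code): with an upstream gradient of +1, the backpropagated quantity is ∂Q/∂θ, and the additive update climbs Q; flipping the upstream sign while keeping the additive update descends it.

using System;

// Toy check: critic Q(a) = -(a - 2)^2 (maximized at a = 2), actor a = theta * s with s = 1.
double theta = 0.0;
const double s = 1.0;
const double alpha = 0.1;

for (int step = 0; step < 100; step++)
{
    double a = theta * s;
    double dQdA = -2.0 * (a - 2.0); // what backprop through the critic yields with an upstream gradient of +1
    double dQdTheta = dQdA * s;     // chain rule through the deterministic actor
    theta += alpha * dQdTheta;      // additive update => ascent on Q
}

Console.WriteLine(theta); // ~2.0; with an upstream gradient of -1 and the same "+" update, theta diverges instead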
🤖 Prompt for AI Agents
In src/ReinforcementLearning/Agents/MADDPG/MADDPGAgent.cs around lines 377 to
467, the actor update is using a negative critic output gradient and also
negating the Q-value, which causes parameter updates to reduce Q instead of
increase it; fix by treating Q as the objective: stop negating the qValue when
accumulating totalLoss (add qValue rather than NumOps.Negate(qValue)) and set
criticOutputGradient[0] = NumOps.FromDouble(1.0) so backprop computes
∂Q/∂action, leaving the parameter update as actorParams[i] =
NumOps.Add(actorParams[i], α * parameterGradients[i]) to perform gradient ascent
on Q.
public override byte[] Serialize()
{
    var state = new
    {
        QTable = _qTable,
        Returns = _returns,
        Epsilon = _epsilon,
        Options = _options
    };
    string json = JsonConvert.SerializeObject(state);
    return System.Text.Encoding.UTF8.GetBytes(json);
}

public override void Deserialize(byte[] data)
{
    if (data is null || data.Length == 0)
    {
        throw new ArgumentException("Serialized data cannot be null or empty", nameof(data));
    }

    string json = System.Text.Encoding.UTF8.GetString(data);
    var state = JsonConvert.DeserializeObject<dynamic>(json);
    if (state is null)
    {
        throw new InvalidOperationException("Deserialization returned null");
    }

    _qTable = JsonConvert.DeserializeObject<Dictionary<string, Dictionary<int, T>>>(state.QTable.ToString()) ?? new Dictionary<string, Dictionary<int, T>>();
    _returns = JsonConvert.DeserializeObject<Dictionary<string, Dictionary<int, List<T>>>>(state.Returns.ToString()) ?? new Dictionary<string, Dictionary<int, List<T>>>();
    _epsilon = state.Epsilon;
}
Deserialize doesn't restore Options, causing potential mismatches.
While Serialize includes Options in the serialized state (line 225), Deserialize never restores them. After deserialization, _options will retain the value from the constructor, which may differ from the saved configuration. This can cause ActionSize or StateSize mismatches leading to index errors or incorrect behavior.
Apply this diff to restore options:
_qTable = JsonConvert.DeserializeObject<Dictionary<string, Dictionary<int, T>>>(state.QTable.ToString()) ?? new Dictionary<string, Dictionary<int, T>>();
_returns = JsonConvert.DeserializeObject<Dictionary<string, Dictionary<int, List<T>>>>(state.Returns.ToString()) ?? new Dictionary<string, Dictionary<int, List<T>>>();
_epsilon = state.Epsilon;
+ _options = JsonConvert.DeserializeObject<MonteCarloOptions<T>>(state.Options.ToString()) ?? _options;
}
🤖 Prompt for AI Agents
In src/ReinforcementLearning/Agents/MonteCarlo/FirstVisitMonteCarloAgent.cs
around lines 218 to 248, Deserialize currently restores QTable, Returns and
Epsilon but omits restoring Options; update Deserialize to read Options from the
deserialized state, cast/convert it back into the concrete options type and
assign it to _options (with a null check and fallback to the existing
constructor value), and then validate that the deserialized Options.ActionSize
and Options.StateSize match any internal structures (or reinitialize dependent
structures like _qTable/_returns if sizes differ) to avoid index mismatches
after deserialization.
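A hedged sketch of a regression test for this round trip (the xUnit attributes/asserts are standard, but the namespace, the options-taking constructor, the settable StateSize/ActionSize properties, and default Json.NET property naming are assumptions that may need adjusting to the actual code):

using AiDotNet.ReinforcementLearning.Agents.MonteCarlo;
using Xunit;

public class FirstVisitMonteCarloAgentSerializationTests
{
    [Fact]
    public void Deserialize_RestoresSavedOptions()
    {
        // Save an agent configured one way...
        var saved = new FirstVisitMonteCarloAgent<double>(
            new MonteCarloOptions<double> { StateSize = 4, ActionSize = 2 }).Serialize();

        // ...and load it into an agent constructed with different options.
        var reloaded = new FirstVisitMonteCarloAgent<double>(
            new MonteCarloOptions<double> { StateSize = 8, ActionSize = 5 });
        reloaded.Deserialize(saved);

        // With Options restored, re-serializing should reflect the saved configuration,
        // not the constructor's (property names assume default Json.NET serialization).
        string json = System.Text.Encoding.UTF8.GetString(reloaded.Serialize());
        Assert.Contains("\"ActionSize\":2", json);
        Assert.DoesNotContain("\"ActionSize\":5", json);
    }
}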
User Story / Context
merge-dev2-to-master
Summary
Verification
Copilot Review Loop (Outcome-Based)
Record counts before/after your last push:
Files Modified
Notes