@@ -62,9 +62,11 @@ public DreamerAgent(DreamerOptions<T> options, IOptimizer<T, Vector<T>, Vector<T
         : base(options)
     {
         _options = options ?? throw new ArgumentNullException(nameof(options));
+
+        // FIX ISSUE 6: Use learning rate from options consistently
         _optimizer = optimizer ?? options.Optimizer ?? new AdamOptimizer<T, Vector<T>, Vector<T>>(this, new AdamOptimizerOptions<T, Vector<T>, Vector<T>>
         {
-            LearningRate = 0.001,
+            LearningRate = _options.LearningRate,
             Beta1 = 0.9,
             Beta2 = 0.999,
             Epsilon = 1e-8
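Note: a quick usage sketch of what this change enables. It is illustrative only; apart from `LearningRate`, the option-initializer shape and the availability of a parameterless `DreamerOptions<T>` constructor are assumptions, not confirmed by this diff.

```csharp
// Hypothetical usage: the learning rate set on the options now reaches the
// fallback Adam optimizer instead of the hard-coded 0.001.
var options = new DreamerOptions<double> { LearningRate = 3e-4 };
var agent = new DreamerAgent<double>(options, optimizer: null);
// With no optimizer supplied and options.Optimizer unset, the agent builds an
// AdamOptimizer whose LearningRate is 3e-4.
```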
@@ -88,6 +90,14 @@ public DreamerAgent(DreamerOptions<T> options, IOptimizer<T, Vector<T>, Vector<T
         _actorNetwork = CreateActorNetwork();
         _valueNetwork = CreateEncoderNetwork(_options.LatentSize, 1);
 
+        // FIX ISSUE 3: Add all networks to Networks list for parameter access
+        Networks.Add(_representationNetwork);
+        Networks.Add(_dynamicsNetwork);
+        Networks.Add(_rewardNetwork);
+        Networks.Add(_continueNetwork);
+        Networks.Add(_actorNetwork);
+        Networks.Add(_valueNetwork);
+
         // Initialize replay buffer
         _replayBuffer = new ReplayBuffers.UniformReplayBuffer<T>(_options.ReplayBufferSize, _options.Seed);
     }
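Note: registering the six sub-networks matters because the base-class parameter accessors work over `Networks`. A minimal sketch of the kind of aggregation this enables, assuming the elements of `Networks` expose `GetParameters()`; the loop itself is illustrative, not the repository's actual `GetParameters` implementation.

```csharp
// Illustrative aggregation over the registered networks: concatenate each
// network's parameter vector into one flat vector, in registration order.
int total = 0;
foreach (var network in Networks)
    total += network.GetParameters().Length;

var aggregated = new Vector<T>(total);
int offset = 0;
foreach (var network in Networks)
{
    var p = network.GetParameters();
    for (int i = 0; i < p.Length; i++)
        aggregated[offset + i] = p[i];
    offset += p.Length;
}
```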
@@ -222,6 +232,18 @@ private T TrainWorldModel(List<ReplayBuffers.Experience<T>> batch)
         var dynamicsParams = _dynamicsNetwork.GetParameters();
         _dynamicsNetwork.UpdateParameters(dynamicsParams);
 
+        // FIX ISSUE 1: Train representation network
+        // Representation network should minimize reconstruction error of latent states
+        var representationGradient = new Vector<T>(latentState.Length);
+        for (int j = 0; j < representationGradient.Length; j++)
+        {
+            // Gradient flows from dynamics prediction error back through representation
+            representationGradient[j] = NumOps.Divide(gradient[j], NumOps.FromDouble(2.0));
+        }
+        _representationNetwork.Backpropagate(Tensor<T>.FromVector(representationGradient));
+        var representationParams = _representationNetwork.GetParameters();
+        _representationNetwork.UpdateParameters(representationParams);
+
         var rewardGradient = new Vector<T>(1);
         rewardGradient[0] = rewardDiff;
         _rewardNetwork.Backpropagate(Tensor<T>.FromVector(rewardGradient));
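Note: for context on why the dynamics prediction error is a usable signal here, the latent state the dynamics model consumes is produced by the representation network, so encoding error shows up in the dynamics prediction. A forward-pass sketch using the names from this diff; `observation` and `action` stand in for values taken from a replay-buffer experience and are assumptions.

```csharp
// Illustrative forward pass: the representation network encodes the observation,
// and the dynamics network predicts the next latent from that encoding plus the action.
var latentState = _representationNetwork.Predict(Tensor<T>.FromVector(observation)).ToVector();
var dynamicsInput = ConcatenateVectors(latentState, action);
var predictedNextLatent = _dynamicsNetwork.Predict(Tensor<T>.FromVector(dynamicsInput)).ToVector();
```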
@@ -258,23 +280,35 @@ private T TrainPolicy()
         // Imagine future trajectory
         var imaginedReturns = ImagineTrajectory(latentState);
 
-        // Update value network
+        // FIX ISSUE 4: Update value network with correct gradient sign
+        // Value network minimizes squared TD error: (return - value)^2
         var predictedValue = _valueNetwork.Predict(Tensor<T>.FromVector(latentState)).ToVector()[0];
         var valueDiff = NumOps.Subtract(imaginedReturns, predictedValue);
         var valueLoss = NumOps.Multiply(valueDiff, valueDiff);
 
+        // Gradient of MSE loss: 2 * (prediction - target) = -2 * (target - prediction)
         var valueGradient = new Vector<T>(1);
-        valueGradient[0] = valueDiff;
+        valueGradient[0] = NumOps.Multiply(NumOps.FromDouble(-2.0), valueDiff);
         _valueNetwork.Backpropagate(Tensor<T>.FromVector(valueGradient));
         var valueParams = _valueNetwork.GetParameters();
         _valueNetwork.UpdateParameters(valueParams);
 
-        // Update actor to maximize value
+        // FIX ISSUE 2: Implement proper policy gradient for actor
+        // Actor maximizes expected return by following gradient of value w.r.t. actions
+        // Use advantage (return - baseline) as policy gradient weight
+        var advantage = valueDiff;
+
+        // Compute value gradient w.r.t. current action to get policy gradient direction
         var action = _actorNetwork.Predict(Tensor<T>.FromVector(latentState)).ToVector();
         var actorGradient = new Vector<T>(action.Length);
+
+        // Policy gradient: advantage * grad_action(log pi(action|state))
+        // For deterministic policy, approximate with advantage-weighted action gradient
         for (int i = 0; i < actorGradient.Length; i++)
         {
-            actorGradient[i] = NumOps.Divide(valueDiff, NumOps.FromDouble(action.Length));
+            // Gradient direction: maximize value by adjusting actions
+            // Positive advantage -> increase action magnitude in current direction
+            actorGradient[i] = NumOps.Multiply(advantage, NumOps.FromDouble(-1.0 / action.Length));
         }
 
         _actorNetwork.Backpropagate(Tensor<T>.FromVector(actorGradient));
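Note: the sign flip on the value gradient follows from differentiating the squared error with respect to the prediction. A small numerical check in plain doubles (independent of NumOps) that the analytic form matches a finite-difference estimate:

```csharp
// loss(v) = (target - v)^2  =>  dloss/dv = -2 * (target - v) = -2 * valueDiff
double target = 1.0, value = 0.4, eps = 1e-6;
double loss      = Math.Pow(target - value, 2);
double lossPlus  = Math.Pow(target - (value + eps), 2);
double numerical = (lossPlus - loss) / eps;   // ~ -1.2
double analytic  = -2.0 * (target - value);   //   -1.2
```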
@@ -300,7 +334,10 @@ private T ImagineTrajectory(Vector<T> initialLatentState)
 
         // Predict reward
         var reward = _rewardNetwork.Predict(Tensor<T>.FromVector(latentState)).ToVector()[0];
-        imaginedReturn = NumOps.Add(imaginedReturn, reward);
+
+        // FIX ISSUE 5: Add discount factor (gamma) to imagination rollout
+        var discountedReward = NumOps.Multiply(reward, NumOps.Pow(NumOps.FromDouble(_options.Gamma), NumOps.FromDouble(step)));
+        imaginedReturn = NumOps.Add(imaginedReturn, discountedReward);
 
         // Predict next latent state
         var dynamicsInput = ConcatenateVectors(latentState, action);
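Note: an equivalent way to apply the discount that avoids recomputing gamma^step with Pow on every iteration. Illustrative only; the `imaginationHorizon` bound is an assumed stand-in for whatever limits the rollout loop in the actual method.

```csharp
// Maintain a running discount across the imagination rollout.
var gamma = NumOps.FromDouble(_options.Gamma);
var discount = NumOps.FromDouble(1.0);
var imaginedReturn = NumOps.FromDouble(0.0);
for (int step = 0; step < imaginationHorizon; step++)   // horizon bound is assumed
{
    var reward = _rewardNetwork.Predict(Tensor<T>.FromVector(latentState)).ToVector()[0];
    imaginedReturn = NumOps.Add(imaginedReturn, NumOps.Multiply(reward, discount));
    discount = NumOps.Multiply(discount, gamma);
    // ... predict the next latent state exactly as in the hunk above ...
}
```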
@@ -373,12 +410,20 @@ public override ModelMetadata<T> GetModelMetadata()
 
     public override byte[] Serialize()
     {
-        throw new NotImplementedException("Dreamer serialization not yet implemented");
+        // FIX ISSUE 8: Use NotSupportedException with clear message
+        throw new NotSupportedException(
+            "Dreamer agent serialization is not supported. " +
+            "Use GetParameters()/SetParameters() for parameter transfer, " +
+            "or save individual network weights separately.");
     }
 
     public override void Deserialize(byte[] data)
     {
-        throw new NotImplementedException("Dreamer deserialization not yet implemented");
+        // FIX ISSUE 8: Use NotSupportedException with clear message
+        throw new NotSupportedException(
+            "Dreamer agent deserialization is not supported. " +
+            "Use GetParameters()/SetParameters() for parameter transfer, " +
+            "or load individual network weights separately.");
     }
 
     public override Vector<T> GetParameters()
@@ -422,9 +467,29 @@ public override void SetParameters(Vector<T> parameters)
 
     public override IFullModel<T, Vector<T>, Vector<T>> Clone()
     {
-        return new DreamerAgent<T>(_options, _optimizer);
+        // FIX ISSUE 7: Clone should copy learned network parameters
+        var clone = new DreamerAgent<T>(_options, _optimizer);
+
+        // Copy all network parameters
+        var parameters = GetParameters();
+        clone.SetParameters(parameters);
+
+        return clone;
     }
 
+    /// <summary>
+    /// Computes gradients for supervised learning scenarios.
+    /// </summary>
+    /// <remarks>
+    /// FIX ISSUE 9: This method uses a simple supervised loss for compatibility with the base class API.
+    /// It does NOT match the agent's internal training procedure, which uses:
+    /// - World model losses (dynamics, reward, continue prediction)
+    /// - Imagination-based policy gradients
+    /// - Value function TD errors
+    ///
+    /// For actual agent training, use Train(), which implements the full Dreamer algorithm.
+    /// This method is provided only for API compatibility and simple supervised fine-tuning scenarios.
+    /// </remarks>
     public override Vector<T> ComputeGradients(
         Vector<T> input,
         Vector<T> target,
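Note: a brief usage sketch for the Clone change in this hunk (illustrative only):

```csharp
// After the fix, a clone starts from the same learned weights as the original.
static void DemonstrateClone(DreamerAgent<double> trained)
{
    var copy = (DreamerAgent<double>)trained.Clone();
    var a = trained.GetParameters();
    var b = copy.GetParameters();
    // a and b hold identical values until either agent is trained further.
}
```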
@@ -450,13 +515,17 @@ public override void ApplyGradients(Vector<T> gradients, T learningRate)
 
     public override void SaveModel(string filepath)
     {
-        var data = Serialize();
-        System.IO.File.WriteAllBytes(filepath, data);
+        // FIX ISSUE 8: Throw NotSupportedException since Serialize is not supported
+        throw new NotSupportedException(
+            "Dreamer agent save/load is not supported. " +
+            "Use GetParameters()/SetParameters() for parameter transfer.");
     }
 
     public override void LoadModel(string filepath)
     {
-        var data = System.IO.File.ReadAllBytes(filepath);
-        Deserialize(data);
+        // FIX ISSUE 8: Throw NotSupportedException since Deserialize is not supported
+        throw new NotSupportedException(
+            "Dreamer agent save/load is not supported. " +
+            "Use GetParameters()/SetParameters() for parameter transfer.");
     }
 }
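Note: with SaveModel/LoadModel now throwing, the path for moving weights between agents is the parameter accessors, as the exception messages say. A usage sketch, assuming both agents were built from compatible options so the parameter vectors line up:

```csharp
// Transfer learned weights from one agent to another without serialization.
var source = new DreamerAgent<double>(options, optimizer);
// ... train source ...
var target = new DreamerAgent<double>(options, optimizer);
target.SetParameters(source.GetParameters());
```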