Skip to content

Commit a568592

Browse files
ooplesclaude
andcommitted
fix: implement SaveState/LoadState for DecisionTreeAsyncRegressionBase
Resolves review comment on DecisionTreeAsyncRegressionBase.cs:1020 - Implemented SaveState by delegating to Serialize() with proper stream validation - Implemented LoadState by delegating to Deserialize() with proper stream validation - Added null checks, stream capability validation, and empty stream validation 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
1 parent 00af3f8 commit a568592

File tree

1 file changed

+12
-2
lines changed

1 file changed

+12
-2
lines changed

src/Regression/DecisionTreeAsyncRegressionBase.cs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1008,14 +1008,24 @@ public virtual void ApplyGradients(Vector<T> gradients, T learningRate)
10081008
/// </summary>
10091009
public virtual void SaveState(Stream stream)
10101010
{
1011-
throw new NotImplementedException("SaveState is not yet implemented for AsyncDecisionTreeRegressionBase. Consider serializing the tree structure explicitly.");
1011+
if (stream == null) throw new ArgumentNullException(nameof(stream));
1012+
if (!stream.CanWrite) throw new ArgumentException("Stream must be writable.", nameof(stream));
1013+
var data = Serialize();
1014+
stream.Write(data, 0, data.Length);
1015+
stream.Flush();
10121016
}
10131017

10141018
/// <summary>
10151019
/// Loads the model's state from a stream.
10161020
/// </summary>
10171021
public virtual void LoadState(Stream stream)
10181022
{
1019-
throw new NotImplementedException("LoadState is not yet implemented for AsyncDecisionTreeRegressionBase. Consider deserializing the tree structure explicitly.");
1023+
if (stream == null) throw new ArgumentNullException(nameof(stream));
1024+
if (!stream.CanRead) throw new ArgumentException("Stream must be readable.", nameof(stream));
1025+
using var ms = new MemoryStream();
1026+
stream.CopyTo(ms);
1027+
var data = ms.ToArray();
1028+
if (data.Length == 0) throw new InvalidOperationException("Stream contains no data.");
1029+
Deserialize(data);
10201030
}
10211031
}

0 commit comments

Comments
 (0)