Skip to content

Commit f10da03

Browse files
committed
fix: optimize loss scaler and fix encapsulation violations
- Combine unscaling and overflow check into single pass for better cache locality (line 292) - Replace protected _data field access with public GetFlatIndexValue/SetFlatIndex methods (lines 204, 243) - Improves performance by reducing passes over gradient data - Maintains proper encapsulation by using public Tensor API Addresses review comments from copilot-pull-request-reviewer
1 parent 353a625 commit f10da03

File tree

1 file changed

+16
-7
lines changed

1 file changed

+16
-7
lines changed

src/MixedPrecision/LossScaler.cs

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -201,8 +201,8 @@ public void UnscaleGradients(Tensor<T> gradients)
201201

202202
for (int i = 0; i < gradients.Length; i++)
203203
{
204-
T scaledValue = gradients._data[i];
205-
gradients._data[i] = _numOps.Multiply(scaledValue, inverseScale);
204+
T scaledValue = gradients.GetFlatIndexValue(i);
205+
gradients.SetFlatIndex(i, _numOps.Multiply(scaledValue, inverseScale));
206206
}
207207
}
208208

@@ -240,7 +240,7 @@ public bool DetectOverflow(Tensor<T> gradients)
240240
{
241241
for (int i = 0; i < gradients.Length; i++)
242242
{
243-
if (HasOverflow(gradients._data[i]))
243+
if (HasOverflow(gradients.GetFlatIndexValue(i)))
244244
{
245245
return true;
246246
}
@@ -284,11 +284,20 @@ public bool UnscaleGradientsAndCheck(Tensor<T> gradients)
284284
{
285285
_totalUpdates++;
286286

287-
// First unscale the gradients
288-
UnscaleGradients(gradients);
287+
// Combine unscaling and overflow check in single pass for better cache locality
288+
T inverseScale = _numOps.FromDouble(1.0 / Scale);
289+
bool hasOverflow = false;
289290

290-
// Check for overflow
291-
bool hasOverflow = DetectOverflow(gradients);
291+
for (int i = 0; i < gradients.Length; i++)
292+
{
293+
T unscaled = _numOps.Multiply(gradients.GetFlatIndexValue(i), inverseScale);
294+
gradients.SetFlatIndex(i, unscaled);
295+
296+
if (!hasOverflow && (_numOps.IsNaN(unscaled) || _numOps.IsInfinity(unscaled)))
297+
{
298+
hasOverflow = true;
299+
}
300+
}
292301

293302
if (hasOverflow)
294303
{

0 commit comments

Comments
 (0)