Commit 2bc248f
fix: remove double-scaling of distributionWeight in ProbabilisticDistillationStrategy
The standard distillation term was being scaled by (1 - _distributionWeight) twice:
1. Once when computing softLoss (line 63)
2. Again when multiplying combinedLoss (line 73 for trueLabels case)
This double scaling incorrectly shrank the soft component by a factor of (1-distributionWeight)^2
instead of (1-distributionWeight).
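A minimal Python sketch of the arithmetic (the strategy itself is C#; the weight and loss values below are hypothetical, and the blend structure is inferred from this commit message):

```python
# Illustrative sketch of the bug: pre-scaling softLoss by (1 - w) and then
# scaling the blended loss by (1 - w) again weights the soft term by (1 - w)**2.
w = 0.3        # _distributionWeight (hypothetical value)
alpha = 0.5    # Alpha: hard/soft blend inside standard distillation
hard_loss, soft_loss, dist_loss = 1.0, 2.0, 1.5

# Buggy flow: softLoss scaled when computed AND again on the blended loss
soft_pre = (1 - w) * soft_loss                       # first scaling (old line 63)
buggy_std = alpha * hard_loss + (1 - alpha) * soft_pre
buggy = (1 - w) * buggy_std + w * dist_loss          # second scaling (old line 73)

# Fixed flow: blend first, apply (1 - w) exactly once at the end
fixed_std = alpha * hard_loss + (1 - alpha) * soft_loss
fixed = (1 - w) * fixed_std + w * dist_loss

# Effective weight on soft_loss: buggy (1-alpha)*(1-w)**2, fixed (1-alpha)*(1-w)
```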
Fixed by:
ComputeLoss:
- Removed distributionWeight from initial softLoss scaling (line 63)
- Computing finalLoss (either combinedLoss or softLoss)
- Applying (1.0 - _distributionWeight) scaling exactly once at the end
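The fixed ComputeLoss flow above can be sketched as follows (a Python stand-in for the C# method; the signature, parameter names, and use of `None` for the missing-true-labels case are assumptions, not the repo's actual API):

```python
def compute_loss(soft_loss, dist_loss, w, alpha=0.5, hard_loss=None):
    """Sketch of the corrected loss combination.

    w     -- _distributionWeight
    alpha -- Alpha, the hard/soft blend within standard distillation
    """
    # finalLoss is combinedLoss when true labels exist, else softLoss alone
    if hard_loss is not None:
        final = alpha * hard_loss + (1 - alpha) * soft_loss
    else:
        final = soft_loss
    # (1.0 - _distributionWeight) applied exactly once, at the end
    return (1 - w) * final + w * dist_loss
```

With no true labels, `compute_loss(2.0, 1.5, 0.3)` gives 0.7 * 2.0 + 0.3 * 1.5 = 1.85, i.e. the soft term carries the full (1 - w) weight rather than (1 - w)^2.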
ComputeGradient:
- Removed distributionWeight from soft gradient scaling
- Removed distributionWeight from hard gradient blending
- Computing combined gradient: Alpha * hardGrad + (1 - Alpha) * softGrad
- Applying (1.0 - _distributionWeight) scaling exactly once per element
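The per-element gradient combination described above can be sketched the same way (Python stand-in; the function name and list-based gradients are illustrative, not the C# implementation):

```python
def combine_gradients(hard_grad, soft_grad, dist_grad, alpha, w):
    """Sketch of the corrected per-element gradient blend."""
    out = []
    for h, s, d in zip(hard_grad, soft_grad, dist_grad):
        combined = alpha * h + (1 - alpha) * s   # Alpha * hard + (1 - Alpha) * soft
        out.append((1 - w) * combined + w * d)   # (1 - w) applied once per element
    return out
```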
Now distributionWeight correctly balances:
- (1 - distributionWeight) * standard_distillation
- distributionWeight * distributional_matching
Note: KL divergence already uses the correct direction, KLDivergence(teacherSoft, studentSoft),
which computes KL(teacher || student), matching the gradient computation.

1 parent c87817d
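The argument order matters because KL divergence is asymmetric. A minimal sketch of the direction the note refers to (illustrative Python; the repo's actual KLDivergence helper is assumed, not shown):

```python
import math

def kl_divergence(teacher, student):
    # Forward KL: KL(p || q) = sum_i p_i * log(p_i / q_i),
    # skipping zero-probability teacher entries.
    return sum(t * math.log(t / s) for t, s in zip(teacher, student) if t > 0)

teacher = [0.7, 0.2, 0.1]   # hypothetical softened teacher distribution
student = [0.5, 0.3, 0.2]   # hypothetical softened student distribution
```

Swapping the arguments computes KL(student || teacher), a different quantity with a different gradient, so the call order must agree with the gradient code.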
File tree
1 file changed: src/KnowledgeDistillation/Strategies (+35 additions, -14 deletions)