Skip to content

Commit cc2558b

Browse files
committed
add option to print gradient for verification
1 parent 6584149 commit cc2558b

File tree

3 files changed

+11
-0
lines changed

3 files changed

+11
-0
lines changed

Benchmarks/BuildingSimulation/PyTorch/PyTorchSimulator.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,7 @@ def fullPipe(simParams):
177177
timesteps = 20
178178
trials = 30
179179
warmup = 3
180+
printGradToCompare = False
180181

181182
for i in range(trials):
182183

@@ -190,6 +191,9 @@ def getGradient(simParams):
190191

191192

192193
gradientTime, gradient = measure(getGradient, simParams)
194+
195+
if printGradToCompare:
196+
print(gradient)
193197

194198
if i >= warmup:
195199
totalForwardTime += forwardTime

Benchmarks/BuildingSimulation/Swift/main.swift

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,7 @@ var trials = 30
220220
var timesteps = 20
221221
var totalPureForwardTime: Double = 0
222222
var totalGradientTime: Double = 0
223+
let printGradToCompare = false
223224

224225
for _ in 0 ..< trials {
225226
let (forwardOnly, _) = try measure {
@@ -232,6 +233,8 @@ for _ in 0 ..< trials {
232233
}
233234
dontLetTheCompilerOptimizeThisAway(grad)
234235

236+
if printGradToCompare {
237+
print(grad)
235238
}
236239

237240
totalPureForwardTime += forwardOnly

Benchmarks/BuildingSimulation/TensorFlow/TensorFlowSimulator.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,7 @@ def fullPipe(simParams):
227227
timesteps = 20
228228
trials = 30
229229
warmup = 3
230+
printGradToCompare = False
230231

231232
for i in range(trials + warmup):
232233

@@ -243,6 +244,9 @@ def getGradient(simParams):
243244

244245
gradientTime, gradient = measure(getGradient, simParams)
245246

247+
if printGradToCompare:
248+
print(gradient)
249+
246250
if i >= warmup:
247251
totalForwardTime += forwardTime
248252
totalGradientTime += gradientTime

0 commit comments

Comments
 (0)