Skip to content

Commit cdafec7

Browse files
clackaryjfilling
authored andcommitted
refactor: Move getGradient out of control loop
1 parent 51b2d6b commit cdafec7

File tree

1 file changed

+9
-7
lines changed

1 file changed

+9
-7
lines changed

Benchmarks/BuildingSimulation/TensorFlow/TensorFlowSimulator.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,15 @@ def fullPipe(simParams):
227227
return loss
228228

229229

230+
@tf.function
231+
def getGradient(simParams):
232+
with tf.GradientTape() as tape:
233+
endTemperature = simulate(simParams)
234+
235+
gradient = tape.gradient(endTemperature, [simParams])
236+
return gradient
237+
238+
230239
totalForwardTime = 0
231240
totalGradientTime = 0
232241

@@ -235,13 +244,6 @@ def fullPipe(simParams):
235244
forwardTime, forwardOutput = measure(fullPipe, SimParamsConstant)
236245

237246
simParams = tf.Variable(SimParamsConstant)
238-
def getGradient(simParams):
239-
with tf.GradientTape() as tape:
240-
endTemperature = simulate(simParams)
241-
242-
gradient = tape.gradient(endTemperature, [simParams])
243-
return gradient
244-
245247

246248
gradientTime, gradient = measure(getGradient, simParams)
247249

0 commit comments

Comments
 (0)