File tree Expand file tree Collapse file tree 1 file changed +9
-7
lines changed
Benchmarks/BuildingSimulation/TensorFlow Expand file tree Collapse file tree 1 file changed +9
-7
lines changed Original file line number Diff line number Diff line change @@ -227,6 +227,15 @@ def fullPipe(simParams):
227227 return loss
228228
229229
230+ @tf .function
231+ def getGradient (simParams ):
232+ with tf .GradientTape () as tape :
233+ endTemperature = simulate (simParams )
234+
235+ gradient = tape .gradient (endTemperature , [simParams ])
236+ return gradient
237+
238+
230239totalForwardTime = 0
231240totalGradientTime = 0
232241
@@ -235,13 +244,6 @@ def fullPipe(simParams):
235244 forwardTime , forwardOutput = measure (fullPipe , SimParamsConstant )
236245
237246 simParams = tf .Variable (SimParamsConstant )
238- def getGradient (simParams ):
239- with tf .GradientTape () as tape :
240- endTemperature = simulate (simParams )
241-
242- gradient = tape .gradient (endTemperature , [simParams ])
243- return gradient
244-
245247
246248 gradientTime , gradient = measure (getGradient , simParams )
247249
You can’t perform that action at this time.
0 commit comments