File tree Expand file tree Collapse file tree 1 file changed +6
-5
lines changed
Benchmarks/BuildingSimulation/PyTorch Expand file tree Collapse file tree 1 file changed +6
-5
lines changed Original file line number Diff line number Diff line change @@ -162,7 +162,7 @@ def measure(function, arguments):
162162 start = time .time ()
163163 result = function (arguments )
164164 end = time .time ()
165- return end - start
165+ return ( end - start , result )
166166
167167
168168def fullPipe (simParams ):
@@ -180,18 +180,19 @@ def fullPipe(simParams):
180180
181181for i in range (trials ):
182182
183- forwardOnly = measure (fullPipe , SimParamsConstant )
183+ inputs = SimParamsConstant
184+ forwardOnlyTime , forwardOutput = measure (fullPipe , inputs )
184185
185186 simParams = SimParamsConstant
186187 def getGradient (simParams ):
187- endTemperature , gradient = torch .autograd .functional . vjp ( simulate , SimParamsConstant )
188+ gradient = torch .autograd .grad ( forwardOutput , inputs )
188189 return gradient
189190
190191
191- gradientTime = measure (getGradient , simParams )
192+ gradientTime , gradient = measure (getGradient , simParams )
192193
193194 if i >= warmup :
194- totalForwardTime += forwardOnly
195+ totalForwardTime += forwardOnlyTime
195196 totalGradientTime += gradientTime
196197
197198
You can’t perform that action at this time.
0 commit comments