Skip to content

Commit aac5bb1

Browse files
authored
Merge pull request #23 from porterchild/improve-pytorch-usage
switch usage of torch.autograd.functional.vjp to torch.autograd.grad …
2 parents 59d9860 + 8b4ae84 commit aac5bb1

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

Benchmarks/BuildingSimulation/PyTorch/PyTorchSimulator.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ def measure(function, arguments):
162162
start = time.time()
163163
result = function(arguments)
164164
end = time.time()
165-
return end - start
165+
return (end - start, result)
166166

167167

168168
def fullPipe(simParams):
@@ -180,18 +180,19 @@ def fullPipe(simParams):
180180

181181
for i in range(trials):
182182

183-
forwardOnly = measure(fullPipe, SimParamsConstant)
183+
inputs = SimParamsConstant
184+
forwardOnlyTime, forwardOutput = measure(fullPipe, inputs)
184185

185186
simParams = SimParamsConstant
186187
def getGradient(simParams):
187-
endTemperature, gradient = torch.autograd.functional.vjp(simulate, SimParamsConstant)
188+
gradient = torch.autograd.grad(forwardOutput, inputs)
188189
return gradient
189190

190191

191-
gradientTime = measure(getGradient, simParams)
192+
gradientTime, gradient = measure(getGradient, simParams)
192193

193194
if i >= warmup:
194-
totalForwardTime += forwardOnly
195+
totalForwardTime += forwardOnlyTime
195196
totalGradientTime += gradientTime
196197

197198

0 commit comments

Comments
 (0)