Skip to content

Commit 8b4ae84

Browse files
committed
switch usage of torch.autograd.functional.vjp to torch.autograd.grad (30% faster)
They give the same numeric result, but torch.autograd.grad is a more standard usage, and performance is better for this use case. torch.autograd.functional.vjp is unnecessary because there's no custom input to the pullback, and it was implicitly calling the pullback with the normal unit vector, which is exactly what torch.autograd.grad does anyway
1 parent 59d9860 commit 8b4ae84

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

Benchmarks/BuildingSimulation/PyTorch/PyTorchSimulator.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ def measure(function, arguments):
162162
start = time.time()
163163
result = function(arguments)
164164
end = time.time()
165-
return end - start
165+
return (end - start, result)
166166

167167

168168
def fullPipe(simParams):
@@ -180,18 +180,19 @@ def fullPipe(simParams):
180180

181181
for i in range(trials):
182182

183-
forwardOnly = measure(fullPipe, SimParamsConstant)
183+
inputs = SimParamsConstant
184+
forwardOnlyTime, forwardOutput = measure(fullPipe, inputs)
184185

185186
simParams = SimParamsConstant
186187
def getGradient(simParams):
187-
endTemperature, gradient = torch.autograd.functional.vjp(simulate, SimParamsConstant)
188+
gradient = torch.autograd.grad(forwardOutput, inputs)
188189
return gradient
189190

190191

191-
gradientTime = measure(getGradient, simParams)
192+
gradientTime, gradient = measure(getGradient, simParams)
192193

193194
if i >= warmup:
194-
totalForwardTime += forwardOnly
195+
totalForwardTime += forwardOnlyTime
195196
totalGradientTime += gradientTime
196197

197198

0 commit comments

Comments
 (0)