Skip to content

Commit 6584149

Browse files
committed
standardize measure function
1 parent d638422 commit 6584149

File tree

4 files changed

+19
-17
lines changed

4 files changed

+19
-17
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,4 @@ DerivedData
99
.build/
1010
.swift-version
1111
*.swp
12+
main

Benchmarks/BuildingSimulation/PyTorch/PyTorchSimulator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ def fullPipe(simParams):
181181
for i in range(trials):
182182

183183
inputs = SimParamsConstant
184-
forwardOnlyTime, forwardOutput = measure(fullPipe, inputs)
184+
forwardTime, forwardOutput = measure(fullPipe, inputs)
185185

186186
simParams = SimParamsConstant
187187
def getGradient(simParams):
@@ -192,7 +192,7 @@ def getGradient(simParams):
192192
gradientTime, gradient = measure(getGradient, simParams)
193193

194194
if i >= warmup:
195-
totalForwardTime += forwardOnlyTime
195+
totalForwardTime += forwardTime
196196
totalGradientTime += gradientTime
197197

198198

Benchmarks/BuildingSimulation/Swift/main.swift

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ func simulate(simParams: SimParams) -> Float {
178178
var quanta = simParams.quanta
179179

180180
slab.temp = simParams.startingTemp
181-
for i in 0 ..< timesteps {
181+
for _ in 0 ..< timesteps {
182182
let tankAndQuanta = updateSourceTank(store: tank, quanta: quanta)
183183
tank = tankAndQuanta.tank
184184
quanta = tankAndQuanta.quanta
@@ -201,12 +201,12 @@ func dontLetTheCompilerOptimizeThisAway<T>(_ x: T) {
201201
blackHole = x
202202
}
203203

204-
func measure(_ block: () throws -> Void) -> Double {
204+
func measure<T>(_ block: () throws -> T) throws -> (time: Double, result: T) {
205205
let t0 = DispatchTime.now()
206-
try! block()
206+
let result = try block()
207207
let t1 = DispatchTime.now()
208208
let elapsed = Double(t1.uptimeNanoseconds - t0.uptimeNanoseconds) / 1E9
209-
return elapsed
209+
return (elapsed, result)
210210
}
211211

212212
@differentiable(reverse)
@@ -222,15 +222,16 @@ var totalPureForwardTime: Double = 0
222222
var totalGradientTime: Double = 0
223223

224224
for _ in 0 ..< trials {
225-
let forwardOnly = measure {
226-
let output = fullPipe(simParams: simParams)
227-
dontLetTheCompilerOptimizeThisAway(output)
225+
let (forwardOnly, _) = try measure {
226+
return fullPipe(simParams: simParams)
228227
}
228+
dontLetTheCompilerOptimizeThisAway(forwardOnly)
229229

230-
var grad: SimParams.TangentVector?
230+
let (gradientTime, grad) = try measure {
231+
return gradient(at: simParams, of: fullPipe)
232+
}
233+
dontLetTheCompilerOptimizeThisAway(grad)
231234

232-
let gradientTime = measure {
233-
grad = gradient(at: simParams, of: fullPipe)
234235
}
235236

236237
totalPureForwardTime += forwardOnly

Benchmarks/BuildingSimulation/TensorFlow/TensorFlowSimulator.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ def measure(function, arguments):
211211
start = time.time()
212212
result = function(arguments)
213213
end = time.time()
214-
return end - start
214+
return (end - start, result)
215215

216216

217217
@tf.function
@@ -230,7 +230,7 @@ def fullPipe(simParams):
230230

231231
for i in range(trials + warmup):
232232

233-
forwardOnly = measure(fullPipe, SimParamsConstant)
233+
forwardTime, forwardOutput = measure(fullPipe, SimParamsConstant)
234234

235235
simParams = tf.Variable(SimParamsConstant)
236236
def getGradient(simParams):
@@ -241,10 +241,10 @@ def getGradient(simParams):
241241
return gradient
242242

243243

244-
gradientTime = measure(getGradient, simParams)
245-
244+
gradientTime, gradient = measure(getGradient, simParams)
245+
246246
if i >= warmup:
247-
totalForwardTime += forwardOnly
247+
totalForwardTime += forwardTime
248248
totalGradientTime += gradientTime
249249

250250

0 commit comments

Comments
 (0)