diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java index 15f527475bc..afa8fca3233 100644 --- a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java +++ b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java @@ -44,6 +44,7 @@ protected void onCreate(Bundle savedInstanceState) { .get(); int numIter = intent.getIntExtra("num_iter", 50); + int numWarmupIter = intent.getIntExtra("num_warm_up_iter", 5); // TODO: Format the string with a parsable format Stats stats = new Stats(); @@ -58,6 +59,10 @@ protected Void doInBackground(Void... voids) { stats.errorCode = module.loadMethod("forward"); stats.loadEnd = System.nanoTime(); + for (int i = 0; i < numWarmupIter; i++) { + module.forward(); + } + for (int i = 0; i < numIter; i++) { long start = System.nanoTime(); module.forward();