Skip to content

Commit 0918c3d

Browse files
authored
fix: fix gt4py metrics extractor in the StencilTest benchmarking (#1111)
Get the metrics key using new gt4py hooks.
1 parent d0c3615 commit 0918c3d

File tree

1 file changed

+46
-16
lines changed

1 file changed

+46
-16
lines changed

model/testing/src/icon4py/model/testing/stencil_tests.py

Lines changed: 46 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -8,25 +8,25 @@
88

99
from __future__ import annotations
1010

11+
import contextlib
1112
import dataclasses
1213
import os
13-
from collections.abc import Callable, Mapping, Sequence
14-
from typing import Any, ClassVar
14+
from collections.abc import Callable, Generator, Mapping, Sequence
15+
from typing import Any, ClassVar, Final
1516

1617
import gt4py.next as gtx
1718
import numpy as np
1819
import pytest
1920
from gt4py import eve
2021
from gt4py.next import (
21-
config as gtx_config,
2222
constructors,
2323
named_collections as gtx_named_collections,
2424
typing as gtx_typing,
2525
)
2626

2727
# TODO(havogt): import will disappear after FieldOperators support `.compile`
2828
from gt4py.next.ffront.decorator import FieldOperator
29-
from gt4py.next.instrumentation import metrics as gtx_metrics
29+
from gt4py.next.instrumentation import hooks as gtx_hooks, metrics as gtx_metrics
3030

3131
from icon4py.model.common import model_backends, model_options
3232
from icon4py.model.common.grid import base
@@ -91,6 +91,8 @@ def test_and_benchmark(
9191
skip_stenciltest_verification = request.config.getoption(
9292
"skip_stenciltest_verification"
9393
) # skip verification if `--skip-stenciltest-verification` CLI option is set
94+
skip_stenciltest_benchmark = benchmark is None or not benchmark.enabled
95+
9496
if not skip_stenciltest_verification:
9597
reference_outputs = self.reference(
9698
_ConnectivityConceptFixer(
@@ -107,7 +109,7 @@ def test_and_benchmark(
107109
input_data=_properly_allocated_input_data, reference_outputs=reference_outputs
108110
)
109111

110-
if benchmark is not None and benchmark.enabled:
112+
if not skip_stenciltest_benchmark:
111113
warmup_rounds = int(os.getenv("ICON4PY_STENCIL_TEST_WARMUP_ROUNDS", "1"))
112114
iterations = int(os.getenv("ICON4PY_STENCIL_TEST_ITERATIONS", "10"))
113115

@@ -124,24 +126,52 @@ def test_and_benchmark(
124126
)
125127

126128
# Collect GT4Py runtime metrics if enabled
127-
if gtx_config.COLLECT_METRICS_LEVEL > 0:
129+
if gtx_metrics.is_any_level_enabled():
130+
metrics_key = None
131+
# Run the program one final time to get the metrics key
132+
METRICS_KEY_EXTRACTOR: Final = "metrics_id_extractor"
133+
134+
@contextlib.contextmanager
135+
def _get_metrics_id_program_callback(
136+
program: gtx_typing.Program,
137+
args: tuple[Any, ...],
138+
offset_provider: gtx.common.OffsetProvider,
139+
enable_jit: bool,
140+
kwargs: dict[str, Any],
141+
) -> Generator[None, None, None]:
142+
yield
143+
# Collect the key after running the program to make sure it is set
144+
nonlocal metrics_key
145+
metrics_key = gtx_metrics.get_current_source_key()
146+
147+
gtx_hooks.program_call_context.register(
148+
_get_metrics_id_program_callback, name=METRICS_KEY_EXTRACTOR
149+
)
150+
_configured_program(
151+
**_properly_allocated_input_data, offset_provider=grid.connectivities
152+
)
153+
gtx_hooks.program_call_context.remove(METRICS_KEY_EXTRACTOR)
154+
assert metrics_key is not None, "Metrics key could not be recovered during run."
155+
assert metrics_key.startswith(
156+
_configured_program.__name__
157+
), f"Metrics key ({metrics_key}) does not start with the program name ({_configured_program.__name__})"
158+
128159
assert (
129160
len(_configured_program._compiled_programs.compiled_programs) == 1
130161
), "Multiple compiled programs found, cannot extract metrics."
131-
# Get compiled programs from the _configured_program passed to test
132-
compiled_programs = _configured_program._compiled_programs.compiled_programs
133-
# Get the pool key necessary to find the right metrics key. There should be only one compiled program in _configured_program
134-
pool_key = next(iter(compiled_programs.keys()))
135-
# Get the metrics key from the pool key to read the corresponding metrics
136-
metrics_key = _configured_program._compiled_programs._metrics_key_from_pool_key(
137-
pool_key
138-
)
139162
metrics_data = gtx_metrics.sources
140163
compute_samples = metrics_data[metrics_key].metrics["compute"].samples
141-
# exclude warmup iterations, one extra iteration for calibrating pytest-benchmark and one for validation (if executed)
164+
# exclude:
165+
# - one for validation (if executed)
166+
# - one extra warmup round for calibrating pytest-benchmark
167+
# - warmup iterations
168+
# - one last round to get the metrics key
142169
initial_program_iterations_to_skip = warmup_rounds * iterations + (
143-
1 if skip_stenciltest_verification else 2
170+
2 if skip_stenciltest_verification else 3
144171
)
172+
assert (
173+
len(compute_samples) > initial_program_iterations_to_skip
174+
), "Not enough samples collected to compute metrics."
145175
benchmark.extra_info["gtx_metrics"] = compute_samples[
146176
initial_program_iterations_to_skip:
147177
]

0 commit comments

Comments (0)