88
99from __future__ import annotations
1010
11+ import contextlib
1112import dataclasses
1213import os
13- from collections .abc import Callable , Mapping , Sequence
14- from typing import Any , ClassVar
14+ from collections .abc import Callable , Generator , Mapping , Sequence
15+ from typing import Any , ClassVar , Final
1516
1617import gt4py .next as gtx
1718import numpy as np
1819import pytest
1920from gt4py import eve
2021from gt4py .next import (
21- config as gtx_config ,
2222 constructors ,
2323 named_collections as gtx_named_collections ,
2424 typing as gtx_typing ,
2525)
2626
2727# TODO(havogt): import will disappear after FieldOperators support `.compile`
2828from gt4py .next .ffront .decorator import FieldOperator
29- from gt4py .next .instrumentation import metrics as gtx_metrics
29+ from gt4py .next .instrumentation import hooks as gtx_hooks , metrics as gtx_metrics
3030
3131from icon4py .model .common import model_backends , model_options
3232from icon4py .model .common .grid import base
@@ -91,6 +91,8 @@ def test_and_benchmark(
9191 skip_stenciltest_verification = request .config .getoption (
9292 "skip_stenciltest_verification"
9393 ) # skip verification if `--skip-stenciltest-verification` CLI option is set
94+ skip_stenciltest_benchmark = benchmark is None or not benchmark .enabled
95+
9496 if not skip_stenciltest_verification :
9597 reference_outputs = self .reference (
9698 _ConnectivityConceptFixer (
@@ -107,7 +109,7 @@ def test_and_benchmark(
107109 input_data = _properly_allocated_input_data , reference_outputs = reference_outputs
108110 )
109111
110- if benchmark is not None and benchmark . enabled :
112+ if not skip_stenciltest_benchmark :
111113 warmup_rounds = int (os .getenv ("ICON4PY_STENCIL_TEST_WARMUP_ROUNDS" , "1" ))
112114 iterations = int (os .getenv ("ICON4PY_STENCIL_TEST_ITERATIONS" , "10" ))
113115
@@ -124,24 +126,52 @@ def test_and_benchmark(
124126 )
125127
126128 # Collect GT4Py runtime metrics if enabled
127- if gtx_config .COLLECT_METRICS_LEVEL > 0 :
129+ if gtx_metrics .is_any_level_enabled ():
130+ metrics_key = None
131+ # Run the program one final time to get the metrics key
132+ METRICS_KEY_EXTRACTOR : Final = "metrics_id_extractor"
133+
134+ @contextlib .contextmanager
135+ def _get_metrics_id_program_callback (
136+ program : gtx_typing .Program ,
137+ args : tuple [Any , ...],
138+ offset_provider : gtx .common .OffsetProvider ,
139+ enable_jit : bool ,
140+ kwargs : dict [str , Any ],
141+ ) -> Generator [None , None , None ]:
142+ yield
143+ # Collect the key after running the program to make sure it is set
144+ nonlocal metrics_key
145+ metrics_key = gtx_metrics .get_current_source_key ()
146+
147+ gtx_hooks .program_call_context .register (
148+ _get_metrics_id_program_callback , name = METRICS_KEY_EXTRACTOR
149+ )
150+ _configured_program (
151+ ** _properly_allocated_input_data , offset_provider = grid .connectivities
152+ )
153+ gtx_hooks .program_call_context .remove (METRICS_KEY_EXTRACTOR )
154+ assert metrics_key is not None , "Metrics key could not be recovered during run."
155+ assert metrics_key .startswith (
156+ _configured_program .__name__
157+ ), f"Metrics key ({ metrics_key } ) does not start with the program name ({ _configured_program .__name__ } )"
158+
128159 assert (
129160 len (_configured_program ._compiled_programs .compiled_programs ) == 1
130161 ), "Multiple compiled programs found, cannot extract metrics."
131- # Get compiled programs from the _configured_program passed to test
132- compiled_programs = _configured_program ._compiled_programs .compiled_programs
133- # Get the pool key necessary to find the right metrics key. There should be only one compiled program in _configured_program
134- pool_key = next (iter (compiled_programs .keys ()))
135- # Get the metrics key from the pool key to read the corresponding metrics
136- metrics_key = _configured_program ._compiled_programs ._metrics_key_from_pool_key (
137- pool_key
138- )
139162 metrics_data = gtx_metrics .sources
140163 compute_samples = metrics_data [metrics_key ].metrics ["compute" ].samples
141- # exclude warmup iterations, one extra iteration for calibrating pytest-benchmark and one for validation (if executed)
164+ # exclude:
165+ # - one for validation (if executed)
166+ # - one extra warmup round for calibrating pytest-benchmark
167+ # - warmup iterations
168+ # - one last round to get the metrics key
142169 initial_program_iterations_to_skip = warmup_rounds * iterations + (
143- 1 if skip_stenciltest_verification else 2
170+ 2 if skip_stenciltest_verification else 3
144171 )
172+ assert (
173+ len (compute_samples ) > initial_program_iterations_to_skip
174+ ), "Not enough samples collected to compute metrics."
145175 benchmark .extra_info ["gtx_metrics" ] = compute_samples [
146176 initial_program_iterations_to_skip :
147177 ]
0 commit comments