diff --git a/model/testing/src/icon4py/model/testing/stencil_tests.py b/model/testing/src/icon4py/model/testing/stencil_tests.py index 8bb0e1eeaa..d19a11ebdf 100644 --- a/model/testing/src/icon4py/model/testing/stencil_tests.py +++ b/model/testing/src/icon4py/model/testing/stencil_tests.py @@ -8,17 +8,17 @@ from __future__ import annotations +import contextlib import dataclasses import os -from collections.abc import Callable, Mapping, Sequence -from typing import Any, ClassVar +from collections.abc import Callable, Generator, Mapping, Sequence +from typing import Any, ClassVar, Final import gt4py.next as gtx import numpy as np import pytest from gt4py import eve from gt4py.next import ( - config as gtx_config, constructors, named_collections as gtx_named_collections, typing as gtx_typing, @@ -26,7 +26,7 @@ # TODO(havogt): import will disappear after FieldOperators support `.compile` from gt4py.next.ffront.decorator import FieldOperator -from gt4py.next.instrumentation import metrics as gtx_metrics +from gt4py.next.instrumentation import hooks as gtx_hooks, metrics as gtx_metrics from icon4py.model.common import model_backends, model_options from icon4py.model.common.grid import base @@ -91,6 +91,8 @@ def test_and_benchmark( skip_stenciltest_verification = request.config.getoption( "skip_stenciltest_verification" ) # skip verification if `--skip-stenciltest-verification` CLI option is set + skip_stenciltest_benchmark = benchmark is None or not benchmark.enabled + if not skip_stenciltest_verification: reference_outputs = self.reference( _ConnectivityConceptFixer( @@ -107,7 +109,7 @@ def test_and_benchmark( input_data=_properly_allocated_input_data, reference_outputs=reference_outputs ) - if benchmark is not None and benchmark.enabled: + if not skip_stenciltest_benchmark: warmup_rounds = int(os.getenv("ICON4PY_STENCIL_TEST_WARMUP_ROUNDS", "1")) iterations = int(os.getenv("ICON4PY_STENCIL_TEST_ITERATIONS", "10")) @@ -124,24 +126,52 @@ def test_and_benchmark( ) # Collect GT4Py runtime metrics if enabled - if gtx_config.COLLECT_METRICS_LEVEL > 0: + if gtx_metrics.is_any_level_enabled(): + metrics_key = None + # Run the program one final time to get the metrics key + METRICS_KEY_EXTRACTOR: Final = "metrics_id_extractor" + + @contextlib.contextmanager + def _get_metrics_id_program_callback( + program: gtx_typing.Program, + args: tuple[Any, ...], + offset_provider: gtx.common.OffsetProvider, + enable_jit: bool, + kwargs: dict[str, Any], + ) -> Generator[None, None, None]: + yield + # Collect the key after running the program to make sure it is set + nonlocal metrics_key + metrics_key = gtx_metrics.get_current_source_key() + + gtx_hooks.program_call_context.register( + _get_metrics_id_program_callback, name=METRICS_KEY_EXTRACTOR + ) + _configured_program( + **_properly_allocated_input_data, offset_provider=grid.connectivities + ) + gtx_hooks.program_call_context.remove(METRICS_KEY_EXTRACTOR) + assert metrics_key is not None, "Metrics key could not be recovered during run." + assert metrics_key.startswith( + _configured_program.__name__ + ), f"Metrics key ({metrics_key}) does not start with the program name ({_configured_program.__name__})" + assert ( len(_configured_program._compiled_programs.compiled_programs) == 1 ), "Multiple compiled programs found, cannot extract metrics." - # Get compiled programs from the _configured_program passed to test - compiled_programs = _configured_program._compiled_programs.compiled_programs - # Get the pool key necessary to find the right metrics key. There should be only one compiled program in _configured_program - pool_key = next(iter(compiled_programs.keys())) - # Get the metrics key from the pool key to read the corresponding metrics - metrics_key = _configured_program._compiled_programs._metrics_key_from_pool_key( - pool_key - ) metrics_data = gtx_metrics.sources compute_samples = metrics_data[metrics_key].metrics["compute"].samples - # exclude warmup iterations, one extra iteration for calibrating pytest-benchmark and one for validation (if executed) + # exclude: + # - one for validation (if executed) + # - one extra warmup round for calibrating pytest-benchmark + # - warmup iterations + # - one last round to get the metrics key initial_program_iterations_to_skip = warmup_rounds * iterations + ( - 1 if skip_stenciltest_verification else 2 + 2 if skip_stenciltest_verification else 3 ) + assert ( + len(compute_samples) > initial_program_iterations_to_skip + ), "Not enough samples collected to compute metrics." benchmark.extra_info["gtx_metrics"] = compute_samples[ initial_program_iterations_to_skip: ]