Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 39 additions & 13 deletions model/testing/src/icon4py/model/testing/stencil_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,25 +8,25 @@

from __future__ import annotations

import contextlib
import dataclasses
import os
from collections.abc import Callable, Mapping, Sequence
from typing import Any, ClassVar
from collections.abc import Callable, Generator, Mapping, Sequence
from typing import Any, ClassVar, Final

import gt4py.next as gtx
import numpy as np
import pytest
from gt4py import eve
from gt4py.next import (
config as gtx_config,
constructors,
named_collections as gtx_named_collections,
typing as gtx_typing,
)

# TODO(havogt): import will disappear after FieldOperators support `.compile`
from gt4py.next.ffront.decorator import FieldOperator
from gt4py.next.instrumentation import metrics as gtx_metrics
from gt4py.next.instrumentation import hooks as gtx_hooks, metrics as gtx_metrics

from icon4py.model.common import model_backends, model_options
from icon4py.model.common.grid import base
Expand Down Expand Up @@ -88,9 +88,40 @@ def test_and_benchmark(
_configured_program: Callable[..., None],
request: pytest.FixtureRequest,
) -> None:
metrics_key = None

skip_stenciltest_verification = request.config.getoption(
"skip_stenciltest_verification"
) # skip verification if `--skip-stenciltest-verification` CLI option is set
skip_stenciltest_benchmark = benchmark is not None and benchmark.enabled

if not (skip_stenciltest_verification and skip_stenciltest_benchmark):
# Precompile and run the program once to get the metrics key
METRICS_KEY_EXTRACTOR: Final = "metrics_id_extractor"
if gtx_metrics.is_any_level_enabled():

@contextlib.contextmanager
def _get_metrics_id_program_callback(
program: gtx_typing.Program,
args: tuple[Any, ...],
offset_provider: gtx.common.OffsetProvider,
enable_jit: bool,
kwargs: dict[str, Any],
) -> Generator[None, None, None]:
yield # run the program
nonlocal metrics_key
metrics_key = gtx_metrics.get_current_source_key()

gtx_hooks.program_call_context.register(
_get_metrics_id_program_callback, name=METRICS_KEY_EXTRACTOR
)

_configured_program(**_properly_allocated_input_data, offset_provider=grid.connectivities)
if gtx_metrics.is_any_level_enabled():
gtx_hooks.program_call_context.remove(METRICS_KEY_EXTRACTOR)

assert metrics_key is not None or not gtx_metrics.is_any_level_enabled()

if not skip_stenciltest_verification:
reference_outputs = self.reference(
_ConnectivityConceptFixer(
Expand Down Expand Up @@ -124,18 +155,13 @@ def test_and_benchmark(
)

# Collect GT4Py runtime metrics if enabled
if gtx_config.COLLECT_METRICS_LEVEL > 0:
if gtx_metrics.is_any_level_enabled():
assert (
metrics_key is not None
), "Metrics key should have been set during the warmup run."
assert (
len(_configured_program._compiled_programs.compiled_programs) == 1
), "Multiple compiled programs found, cannot extract metrics."
# Get compiled programs from the _configured_program passed to test
compiled_programs = _configured_program._compiled_programs.compiled_programs
# Get the pool key necessary to find the right metrics key. There should be only one compiled program in _configured_program
pool_key = next(iter(compiled_programs.keys()))
# Get the metrics key from the pool key to read the corresponding metrics
metrics_key = _configured_program._compiled_programs._metrics_key_from_pool_key(
pool_key
)
metrics_data = gtx_metrics.sources
compute_samples = metrics_data[metrics_key].metrics["compute"].samples
# exclude warmup iterations, one extra iteration for calibrating pytest-benchmark and one for validation (if executed)
Expand Down
Loading