Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 146 additions & 16 deletions dvc/parsing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from dvc.exceptions import DvcException
from dvc.log import logger
from dvc.parsing.interpolate import ParseError
from dvc.parsing.interpolate import ParseError, get_matches
from dvc.utils.objects import cached_property

from .context import (
Expand Down Expand Up @@ -49,6 +49,9 @@
PLOTS_KWD = "plots"
STAGES_KWD = "stages"

# Reserved namespace for params interpolation (e.g., ${param.key})
PARAMS_NAMESPACE = "param"

FOREACH_KWD = "foreach"
MATRIX_KWD = "matrix"
DO_KWD = "do"
Expand Down Expand Up @@ -120,6 +123,27 @@ def func(s: "DictStrAny") -> None:
return recurse(func)(data)


def has_params_interpolation(value: Any) -> bool:
"""Check if value contains ${PARAMS_NAMESPACE.*} interpolation."""
prefix = f"{PARAMS_NAMESPACE}."

def check_str(s: str) -> bool:
matches = get_matches(s)
for match in matches:
inner = match["inner"]
if inner.startswith(prefix):
return True
return False

if isinstance(value, str):
return check_str(value)
if isinstance(value, (list, tuple)):
return any(has_params_interpolation(item) for item in value)
if isinstance(value, dict):
return any(has_params_interpolation(v) for v in value.values())
return False


Definition = Union["ForeachDefinition", "EntryDefinition", "MatrixDefinition"]


Expand Down Expand Up @@ -150,12 +174,24 @@ def __init__(self, repo: "Repo", wdir: str, d: dict):
check_interpolations(vars_, VARS_KWD, self.relpath)
self.context: Context = Context()

# Reserve namespace to prevent conflicts with
# ${PARAMS_NAMESPACE.*} interpolation - must be done before loading vars
self.context._reserved_keys[PARAMS_NAMESPACE] = True

try:
args = fs, vars_, wdir # load from `vars` section
self.context.load_from_vars(*args, default=DEFAULT_PARAMS_FILE)
except ContextError as exc:
format_and_raise(exc, "'vars'", self.relpath)

# Load global-level params for ${PARAMS_NAMESPACE.*} interpolation
global_params = d.get(PARAMS_KWD, [])
if global_params:
try:
self.context.load_params(fs, global_params, wdir)
except ContextError as exc:
format_and_raise(exc, f"'{PARAMS_KWD}'", self.relpath)

# we use `tracked_vars` to keep a dictionary of used variables
# by the interpolated entries.
self.tracked_vars: dict[str, Mapping] = {}
Expand Down Expand Up @@ -295,6 +331,82 @@ def resolve(self, **kwargs):
except ContextError as exc:
format_and_raise(exc, f"stage '{self.name}'", self.relpath)

def _load_stage_vars(self, context, definition, wdir, name):
"""Load stage-level vars into context."""
vars_ = definition.pop(VARS_KWD, [])
# TODO: Should `vars` be templatized?
check_interpolations(vars_, f"{self.where}.{name}.vars", self.relpath)

if vars_:
# Optimization: Lookahead if it has any vars, if it does not, we
# don't need to clone them.
context = Context.clone(context)

try:
fs = self.resolver.fs
context.load_from_vars(fs, vars_, wdir, stage_name=name)
except VarsAlreadyLoaded as exc:
format_and_raise(exc, f"'{self.where}.{name}.vars'", self.relpath)

return context, vars_

def _load_stage_params(self, context, definition, wdir, name, vars_):
"""Load stage-level params into context."""
# Load stage-level params for ${params.*} interpolation
# Note: params field is not popped, as it's needed for dependency tracking
stage_params = definition.get(PARAMS_KWD, [])
if not stage_params:
return context, stage_params

# Clone context if not already cloned
if not vars_:
context = Context.clone(context)

# Resolve interpolations in params field (e.g., ${item.model_type})
# This allows dynamic param file loading based on foreach/matrix values
resolved_params = context.resolve(
stage_params,
skip_interpolation_checks=True,
key=PARAMS_KWD,
config=self.resolver.parsing_config,
)

try:
fs = self.resolver.fs
context.load_params(fs, resolved_params, wdir)
except ContextError as exc:
format_and_raise(exc, f"'{self.where}.{name}.params'", self.relpath)

return context, stage_params

def _check_params_ambiguity(self, context, stage_params, tracked_data, name):
"""Check for ambiguous params keys in tracked data."""
if not (
stage_params
or (
hasattr(self.resolver.context, "_params_context")
and self.resolver.context._params_context
)
):
return

# Extract used params keys from tracked data
used_params_keys = set()
for source_file in tracked_data:
# Check if this source is a params file
if context._params_sources:
for key in tracked_data[source_file]:
# Extract top-level key
top_key = key.split(".")[0]
if top_key in context._params_sources:
used_params_keys.add(key)

if used_params_keys:
try:
context.check_params_ambiguity(used_params_keys)
except ContextError as exc:
format_and_raise(exc, f"stage '{name}'", self.relpath)

def resolve_stage(self, skip_checks: bool = False) -> "DictStrAny":
context = self.context
name = self.name
Expand All @@ -309,19 +421,11 @@ def resolve_stage(self, skip_checks: bool = False) -> "DictStrAny":
definition = deepcopy(self.definition)

wdir = self._resolve_wdir(context, name, definition.get(WDIR_KWD))
vars_ = definition.pop(VARS_KWD, [])
# FIXME: Should `vars` be templatized?
check_interpolations(vars_, f"{self.where}.{name}.vars", self.relpath)
if vars_:
# Optimization: Lookahead if it has any vars, if it does not, we
# don't need to clone them.
context = Context.clone(context)

try:
fs = self.resolver.fs
context.load_from_vars(fs, vars_, wdir, stage_name=name)
except VarsAlreadyLoaded as exc:
format_and_raise(exc, f"'{self.where}.{name}.vars'", self.relpath)
context, vars_ = self._load_stage_vars(context, definition, wdir, name)
context, stage_params = self._load_stage_params(
context, definition, wdir, name, vars_
)

logger.trace("Context during resolution of stage %s:\n%s", name, context)

Expand All @@ -337,12 +441,25 @@ def resolve_stage(self, skip_checks: bool = False) -> "DictStrAny":
for key, value in definition.items()
}

self._check_params_ambiguity(context, stage_params, tracked_data, name)
self.resolver.track_vars(name, tracked_data)
return {name: resolved}

def _resolve(
self, context: "Context", value: Any, key: str, skip_checks: bool
) -> "DictStrAny":
# Check if params interpolation is used in restricted fields
if has_params_interpolation(value):
allowed_fields = {"cmd"}
if key not in allowed_fields:
raise ResolveError(
f"failed to parse '{self.where}.{self.name}.{key}' "
f"in '{self.relpath}': "
f"${{{PARAMS_NAMESPACE}.*}} interpolation is not allowed "
f"in '{key}' field, "
f"only in: {', '.join(sorted(allowed_fields))}"
)

try:
return context.resolve(
value,
Expand Down Expand Up @@ -377,6 +494,7 @@ def __init__(
assert MATRIX_KWD not in definition
self.foreach_data = definition[FOREACH_KWD]
self._template = definition[DO_KWD]
self.params_list = definition.get(PARAMS_KWD, [])

self.pair = IterationPair()
self.where = where
Expand Down Expand Up @@ -487,9 +605,21 @@ def _each_iter(self, key: str) -> "DictStrAny":
# the no. of items to be generated which means more cloning,
# i.e. quadratic complexity).
generated = self._generate_name(key)
entry = EntryDefinition(
self.resolver, self.context, generated, self.template
)

# Load params if defined at foreach level
context = self.context
if self.params_list:
context = Context.clone(self.context)
try:
fs = self.resolver.fs
wdir = self.resolver.wdir
context.load_params(fs, self.params_list, wdir)
except ContextError as exc:
format_and_raise(
exc, f"'{self.where}.{self.name}.params'", self.relpath
)

entry = EntryDefinition(self.resolver, context, generated, self.template)
try:
# optimization: skip checking for syntax errors on each foreach
# generated stages. We do it once when accessing template.
Expand Down
Loading