Skip to content

Commit 2dcbc2e

Browse files
authored
fix: write extra_vars into a file when call ansible-runner (#1343)
Treat all str values in extra_vars as unsafe. Write them to a yaml file to be consumed by the ansible-runner.
1 parent 05598a1 commit 2dcbc2e

File tree

4 files changed

+109
-8
lines changed

4 files changed

+109
-8
lines changed
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# Copyright 2025 Red Hat, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import yaml
16+
17+
18+
class UnsafeString(str):
19+
pass
20+
21+
22+
def unsafe_string_representer(
23+
dumper: yaml.SafeDumper, obj: UnsafeString
24+
) -> yaml.ScalarNode:
25+
return dumper.represent_scalar("!unsafe", obj)
26+
27+
28+
def get_dumper() -> yaml.SafeDumper:
29+
"""Add representers to a YAML seriailizer."""
30+
safe_dumper = yaml.SafeDumper
31+
safe_dumper.add_representer(UnsafeString, unsafe_string_representer)
32+
return safe_dumper
33+
34+
35+
def dump(filename: str, data: dict) -> None:
36+
"""Write data to a file as YAML with unsafe strings."""
37+
if not isinstance(data, dict):
38+
raise TypeError("Data must be a dictionary")
39+
40+
def transform_strings(data):
41+
"""Recursively transform any string type to UnsafeString."""
42+
if isinstance(data, str):
43+
return UnsafeString(data)
44+
if isinstance(data, dict):
45+
return {k: transform_strings(v) for k, v in data.items()}
46+
if isinstance(data, list):
47+
return [transform_strings(v) for v in data]
48+
if isinstance(data, tuple):
49+
return tuple(transform_strings(v) for v in data)
50+
if isinstance(data, set):
51+
return {transform_strings(v) for v in data}
52+
return data
53+
54+
data = transform_strings(data)
55+
56+
with open(filename, "w") as f:
57+
yaml.dump(data, f, Dumper=get_dumper())

src/aap_eda/services/project/scm.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929

3030
from aap_eda.core.models import EdaCredential
3131
from aap_eda.core.types import StrPath
32-
from aap_eda.core.utils.credentials import inputs_from_store
32+
from aap_eda.core.utils import credentials, safe_yaml
3333

3434
logger = logging.getLogger(__name__)
3535

@@ -151,7 +151,9 @@ def clone(
151151
gpg_key_file = None
152152
gpg_home_dir = None
153153
if credential:
154-
inputs = inputs_from_store(credential.inputs.get_secret_value())
154+
inputs = credentials.inputs_from_store(
155+
credential.inputs.get_secret_value()
156+
)
155157
secret = inputs.get("password", "")
156158
key_data = inputs.get("ssh_key_data", "")
157159

@@ -171,7 +173,7 @@ def clone(
171173
key_password = inputs.get("ssh_key_unlock")
172174

173175
if gpg_credential:
174-
gpg_inputs = inputs_from_store(
176+
gpg_inputs = credentials.inputs_from_store(
175177
gpg_credential.inputs.get_secret_value()
176178
)
177179
gpg_key = gpg_inputs.get("gpg_public_key")
@@ -295,12 +297,15 @@ def __call__(
295297
env_vars: dict,
296298
):
297299
with tempfile.TemporaryDirectory(prefix="EDA_RUNNER") as data_dir:
300+
env_dir = os.path.join(data_dir, "env")
301+
os.makedirs(env_dir)
302+
safe_yaml.dump(os.path.join(env_dir, "extravars"), extra_vars)
303+
298304
outputs = io.StringIO()
299305
with contextlib.redirect_stdout(outputs):
300306
runner = ansible_runner.run(
301307
private_data_dir=data_dir,
302308
playbook=PLAYBOOK,
303-
extravars=extra_vars,
304309
envvars=env_vars,
305310
)
306311

@@ -355,7 +360,4 @@ def is_refspec_valid(refspec: str, is_branch: bool) -> bool:
355360
f"{result.returncode}"
356361
)
357362
return False
358-
if is_branch and all(arg in refspec for arg in ["{{", "}}", "lookup"]):
359-
logger.error(f"branch {refspec} is invalid")
360-
return False
361363
return True

tests/unit/test_scm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ def test_is_git_url_valid(url: str, expected: bool):
245245
("refs/heads/branch1", False, True),
246246
("@{-1}", True, False),
247247
("branch1", True, True),
248-
("{{lookup('branch1')}}", True, False),
248+
("{{lookup('branch1')}}", True, True),
249249
("{{lookup('branch1')}}", False, False),
250250
],
251251
)

tests/unit/test_utils.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import tempfile
1516
from importlib.metadata import PackageNotFoundError, version
1617
from unittest.mock import patch
1718

@@ -21,6 +22,7 @@
2122
from drf_spectacular.utils import OpenApiParameter
2223
from rest_framework import serializers
2324

25+
from aap_eda.core.utils import safe_yaml
2426
from aap_eda.core.utils.strings import extract_variables
2527
from aap_eda.utils import (
2628
get_package_version,
@@ -213,3 +215,43 @@ def test_identifier_with_invalid_first_character_raises():
213215
too_long = "1" + "a" * 62
214216
with pytest.raises(ValueError, match="invalid first character"):
215217
sanitize_postgres_identifier(too_long)
218+
219+
220+
TEST_YAML_DATA = {
221+
"dict": {"a": "b", "c": "d"},
222+
"list": ["a", "b", "c"],
223+
"tuple": ("a", "b", "c"),
224+
"set": {"a", "b", "c"},
225+
"num": 300,
226+
}
227+
228+
229+
TEST_YAML_OUTPUT = """dict:
230+
a: !unsafe 'b'
231+
c: !unsafe 'd'
232+
list:
233+
- !unsafe 'a'
234+
- !unsafe 'b'
235+
- !unsafe 'c'
236+
num: 300
237+
set: !!set
238+
!unsafe 'a': null
239+
!unsafe 'b': null
240+
!unsafe 'c': null
241+
tuple:
242+
- !unsafe 'a'
243+
- !unsafe 'b'
244+
- !unsafe 'c'
245+
"""
246+
247+
248+
def test_dump_safe_yaml():
249+
with tempfile.NamedTemporaryFile() as f:
250+
safe_yaml.dump(f.name, TEST_YAML_DATA)
251+
with open(f.name) as f:
252+
assert f.read() == TEST_YAML_OUTPUT
253+
254+
255+
def test_dump_safe_yaml_invalid_type():
256+
with pytest.raises(TypeError, match="Data must be a dictionary"):
257+
safe_yaml.dump("test", "test")

0 commit comments

Comments
 (0)