|
16 | 16 | import uuid |
17 | 17 | import pytest |
18 | 18 | from unittest import mock |
| 19 | +import random |
| 20 | +import string |
19 | 21 | from dataclasses import asdict, dataclass |
20 | 22 |
|
21 | 23 | from flytekit import LaunchPlan, kwtypes, WorkflowExecutionPhase, task, workflow |
22 | 24 | from flytekit.configuration import Config, ImageConfig, SerializationSettings |
23 | 25 | from flytekit.core.launch_plan import reference_launch_plan |
24 | 26 | from flytekit.core.task import reference_task |
25 | 27 | from flytekit.core.workflow import reference_workflow |
| 28 | +from flytekit.models import task as task_models |
26 | 29 | from flytekit.exceptions.user import FlyteAssertion, FlyteEntityNotExistException |
27 | 30 | from flytekit.extras.sqlite3.task import SQLite3Config, SQLite3Task |
28 | 31 | from flytekit.remote.remote import FlyteRemote |
@@ -1252,3 +1255,106 @@ def test_register_wf_twice(register): |
1252 | 1255 | ] |
1253 | 1256 | ) |
1254 | 1257 | assert out.returncode == 0 |
| 1258 | + |
| 1259 | + |
| 1260 | +def test_register_wf_with_resource_requests_override(register): |
| 1261 | + # Save the version here to retrieve the created task later |
| 1262 | + version = str(uuid.uuid4()) |
| 1263 | + |
| 1264 | + cpu = "1300m" |
| 1265 | + mem = "1100Mi" |
| 1266 | + |
| 1267 | + # Register the workflow with overridden default resources |
| 1268 | + out = subprocess.run( |
| 1269 | + [ |
| 1270 | + "pyflyte", |
| 1271 | + "--verbose", |
| 1272 | + "-c", |
| 1273 | + CONFIG, |
| 1274 | + "register", |
| 1275 | + "--resource-requests", |
| 1276 | + f"cpu={cpu},mem={mem}", |
| 1277 | + "--image", |
| 1278 | + IMAGE, |
| 1279 | + "--project", |
| 1280 | + PROJECT, |
| 1281 | + "--domain", |
| 1282 | + DOMAIN, |
| 1283 | + "--version", |
| 1284 | + version, |
| 1285 | + MODULE_PATH / "hello_world.py", |
| 1286 | + ] |
| 1287 | + ) |
| 1288 | + assert out.returncode == 0 |
| 1289 | + |
| 1290 | + # Retrieve the created task |
| 1291 | + remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN) |
| 1292 | + task = remote.fetch_task(name="basic.hello_world.say_hello", version=version) |
| 1293 | + assert task.template.container is not None |
| 1294 | + assert task.template.container.resources == task_models.Resources( |
| 1295 | + requests=[ |
| 1296 | + task_models.Resources.ResourceEntry( |
| 1297 | + name=task_models.Resources.ResourceName.CPU, |
| 1298 | + value=cpu, |
| 1299 | + ), |
| 1300 | + task_models.Resources.ResourceEntry( |
| 1301 | + name=task_models.Resources.ResourceName.MEMORY, |
| 1302 | + value=mem, |
| 1303 | + ), |
| 1304 | + ], |
| 1305 | + limits=[], |
| 1306 | + ) |
| 1307 | + |
| 1308 | + |
| 1309 | +def test_run_wf_with_resource_requests_override(register): |
| 1310 | + # Save the execution id here to retrieve the created execution later |
| 1311 | + prefix = random.choice(string.ascii_lowercase) |
| 1312 | + short_random_part = uuid.uuid4().hex[:8] |
| 1313 | + execution_id = f"{prefix}{short_random_part}" |
| 1314 | + |
| 1315 | + cpu = "500m" |
| 1316 | + mem = "1Gi" |
| 1317 | + |
| 1318 | + # Register the workflow with overridden default resources |
| 1319 | + out = subprocess.run( |
| 1320 | + [ |
| 1321 | + "pyflyte", |
| 1322 | + "--verbose", |
| 1323 | + "-c", |
| 1324 | + CONFIG, |
| 1325 | + "run", |
| 1326 | + "--remote", |
| 1327 | + "--resource-requests", |
| 1328 | + f"cpu={cpu},mem={mem}", |
| 1329 | + "--project", |
| 1330 | + PROJECT, |
| 1331 | + "--domain", |
| 1332 | + DOMAIN, |
| 1333 | + "--name", |
| 1334 | + execution_id, |
| 1335 | + MODULE_PATH / "hello_world.py", |
| 1336 | + "my_wf" |
| 1337 | + ] |
| 1338 | + ) |
| 1339 | + assert out.returncode == 0 |
| 1340 | + |
| 1341 | + # Retrieve the created task |
| 1342 | + remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN) |
| 1343 | + execution = remote.fetch_execution(name=execution_id) |
| 1344 | + execution = remote.wait(execution=execution) |
| 1345 | + version = execution.spec.launch_plan.version |
| 1346 | + task = remote.fetch_task(name="basic.hello_world.say_hello", version=version) |
| 1347 | + assert task.template.container is not None |
| 1348 | + assert task.template.container.resources == task_models.Resources( |
| 1349 | + requests=[ |
| 1350 | + task_models.Resources.ResourceEntry( |
| 1351 | + name=task_models.Resources.ResourceName.CPU, |
| 1352 | + value=cpu, |
| 1353 | + ), |
| 1354 | + task_models.Resources.ResourceEntry( |
| 1355 | + name=task_models.Resources.ResourceName.MEMORY, |
| 1356 | + value=mem, |
| 1357 | + ), |
| 1358 | + ], |
| 1359 | + limits=[], |
| 1360 | + ) |
0 commit comments