@@ -405,3 +405,179 @@ def test_array_assertion(self):
405405
406406 with pytest .raises (AssertionError , match = "array is not supported" ):
407407 request .materialize ()
408+
def test_command_groups_env_vars(self):
    """Render a grouped Ray request and check the global export and extra commands.

    A ``run_as_group`` executor is given two resource requests, each with
    its own ``env_vars``, and three command groups are submitted. The test
    asserts on the rendered script only: the executor-level export, the
    "Run extra commands" marker, ``srun``, and the ``cmd1``/``cmd2`` commands.
    """
    # Tunnel stub; only ``job_dir`` is configured on it.
    ssh_tunnel = Mock(spec=SSHTunnel)
    ssh_tunnel.job_dir = "/tmp/test_jobs"

    # (container image, group env vars, container mounts) per resource request.
    group_specs = (
        ("image1", {"GROUP1_ENV": "group1_value"}, ["/mount1"]),
        ("image2", {"GROUP2_ENV": "group2_value"}, ["/mount2"]),
    )

    executor = SlurmExecutor(
        account="test_account",
        env_vars={"GLOBAL_ENV": "global_value"},
    )
    executor.run_as_group = True
    executor.resource_group = [
        SlurmExecutor.ResourceRequest(
            packager=Mock(),
            nodes=1,
            ntasks_per_node=1,
            container_image=image,
            env_vars=group_env,
            container_mounts=mounts,
        )
        for image, group_env, mounts in group_specs
    ]
    executor.tunnel = ssh_tunnel

    ray_request = SlurmRayRequest(
        name="test-ray-cluster",
        cluster_dir="/tmp/test_jobs/test-ray-cluster",
        template_name="ray.sub.j2",
        executor=executor,
        command_groups=[["cmd0"], ["cmd1"], ["cmd2"]],
        launch_cmd=["sbatch", "--parsable"],
    )
    rendered = ray_request.materialize()

    # Checked in the same order as before so the first failure is unchanged:
    # the global export, then the extra-commands section marker, then the
    # commands from the groups after index 0 (cmd0 itself is not asserted on).
    expected_fragments = (
        "export GLOBAL_ENV=global_value",
        "# Run extra commands",
        "srun",
        "cmd1",
        "cmd2",
    )
    for fragment in expected_fragments:
        assert fragment in rendered
461+
def test_command_groups_without_resource_group(self):
    """Render a Ray request with two command groups and no resource groups.

    Asserts that the generated script exports the executor-level variable
    and contains ``srun``, ``--overlap`` and the second command (``cmd1``).
    """
    # Tunnel stub; only ``job_dir`` is configured on it.
    ssh_tunnel = Mock(spec=SSHTunnel)
    ssh_tunnel.job_dir = "/tmp/test_jobs"

    executor = SlurmExecutor(
        account="test_account",
        env_vars={"GLOBAL_ENV": "global_value"},
    )
    executor.tunnel = ssh_tunnel

    rendered = SlurmRayRequest(
        name="test-ray-cluster",
        cluster_dir="/tmp/test_jobs/test-ray-cluster",
        template_name="ray.sub.j2",
        executor=executor,
        command_groups=[["cmd0"], ["cmd1"]],
        launch_cmd=["sbatch", "--parsable"],
    ).materialize()

    # Global export first, then the overlapping srun invocation for the
    # command at index 1 (the command at index 0 is not asserted on).
    expected_fragments = (
        "export GLOBAL_ENV=global_value",
        "srun",
        "--overlap",
        "cmd1",
    )
    for fragment in expected_fragments:
        assert fragment in rendered
489+
def test_env_vars_formatting(self):
    """Check that executor env vars appear as ``export NAME=value`` lines.

    Covers a value containing spaces, a colon-separated path, an empty
    value and a numeric string. No command groups are passed.
    """
    # Tunnel stub; only ``job_dir`` is configured on it.
    ssh_tunnel = Mock(spec=SSHTunnel)
    ssh_tunnel.job_dir = "/tmp/test_jobs"

    executor = SlurmExecutor(
        account="test_account",
        env_vars={
            "VAR_WITH_SPACES": "value with spaces",
            "PATH_VAR": "/usr/bin:/usr/local/bin",
            "EMPTY_VAR": "",
            "NUMBER_VAR": "123",
        },
    )
    executor.tunnel = ssh_tunnel

    ray_request = SlurmRayRequest(
        name="test-ray-cluster",
        cluster_dir="/tmp/test_jobs/test-ray-cluster",
        template_name="ray.sub.j2",
        executor=executor,
        launch_cmd=["sbatch", "--parsable"],
    )
    rendered = ray_request.materialize()

    # Expected lines are spelled out literally (values are emitted unquoted)
    # and checked in declaration order.
    expected_exports = (
        "export VAR_WITH_SPACES=value with spaces",
        "export PATH_VAR=/usr/bin:/usr/local/bin",
        "export EMPTY_VAR=",
        "export NUMBER_VAR=123",
    )
    for export_line in expected_exports:
        assert export_line in rendered
519+
def test_group_env_vars_integration(self):
    """Render a two-group Ray request modelled on a reference Slurm script.

    The original note ties this scenario to ``group_resource_req_slurm.sh``.
    Only the executor-level export is asserted on.

    NOTE(review): the per-group ``CUSTOM_ENV_1`` values are configured but
    never asserted on; the original comments say that handling happens in
    the template rendering -- confirm whether an assertion should be added.
    """
    # Tunnel stub; only ``job_dir`` is configured on it.
    ssh_tunnel = Mock(spec=SSHTunnel)
    ssh_tunnel.job_dir = "/some/job/dir"

    executor = SlurmExecutor(
        account="your_account",
        partition="your_partition",
        time="00:30:00",
        nodes=1,
        ntasks_per_node=8,
        gpus_per_node=8,
        container_image="some-image",
        container_mounts=["/some/job/dir/sample_job:/nemo_run"],
        env_vars={"ENV_VAR": "value"},
    )
    executor.run_as_group = True

    # The two groups differ only in container image. Index 0 is described
    # in the original as the head/main command's group, index 1 as the second.
    group_images = ("some-image", "different_container_image")
    executor.resource_group = [
        SlurmExecutor.ResourceRequest(
            packager=Mock(),
            nodes=1,
            ntasks_per_node=8,
            container_image=image,
            env_vars={"CUSTOM_ENV_1": "some_value_1"},
            container_mounts=["/some/job/dir/sample_job:/nemo_run"],
        )
        for image in group_images
    ]
    executor.tunnel = ssh_tunnel

    server_cmd = ["bash ./scripts/start_server.sh"]
    client_cmd = ["bash ./scripts/echo.sh server_host=$het_group_host_0"]
    ray_request = SlurmRayRequest(
        name="sample_job",
        cluster_dir="/some/job/dir/sample_job",
        template_name="ray.sub.j2",
        executor=executor,
        command_groups=[server_cmd, client_cmd],
        launch_cmd=["sbatch", "--parsable"],
    )
    rendered = ray_request.materialize()

    # Global env vars should be exported in the setup section.
    assert "export ENV_VAR=value" in rendered
0 commit comments