Skip to content

Commit 83f04f8

Browse files
Use newer SDK parameters to set Ray head requests and limits
1 parent 6ecda2d commit 83f04f8

File tree

2 files changed

+8 -4 lines changed

2 files changed

+8 -4 lines changed

examples/ray-finetune-llm-deepspeed/ray_finetune_llm_deepspeed.ipynb

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -63,10 +63,12 @@
6363
" num_workers=7,\n",
6464
" worker_cpu_requests=16,\n",
6565
" worker_cpu_limits=16,\n",
66-
" head_cpus=16,\n",
66+
" head_cpu_requests=16,\n",
67+
" head_cpu_limits=16,\n",
6768
" worker_memory_requests=128,\n",
6869
" worker_memory_limits=256,\n",
69-
" head_memory=128,\n",
70+
" head_memory_requests=128,\n",
71+
" head_memory_limits=256,\n",
7072
" # Use the following parameters with NVIDIA GPUs\n",
7173
" image=\"quay.io/rhoai/ray:2.35.0-py39-cu121-torch24-fa26\",\n",
7274
" head_extended_resource_requests={'nvidia.com/gpu':1},\n",

tests/odh/ray_finetune_llm_deepspeed_test.go

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -59,14 +59,16 @@ func rayFinetuneLlmDeepspeed(t *testing.T, numGpus int, modelName string, modelC
5959
"token = ''": fmt.Sprintf("token = '%s'", userToken),
6060
"server = ''": fmt.Sprintf("server = '%s'", GetOpenShiftApiUrl(test)),
6161
"namespace='ray-finetune-llm-deepspeed'": fmt.Sprintf("namespace='%s'", namespace.Name),
62-
"head_cpus=16": "head_cpus=2",
62+
"head_cpu_requests=16": "head_cpu_requests=2",
63+
"head_cpu_limits=16": "head_cpu_limits=2",
6364
"head_extended_resource_requests=1": "head_extended_resource_requests=0",
6465
"num_workers=7": "num_workers=1",
6566
"worker_cpu_requests=16": "worker_cpu_requests=4",
6667
"worker_cpu_limits=16": "worker_cpu_limits=4",
6768
"worker_memory_requests=128": "worker_memory_requests=64",
6869
"worker_memory_limits=256": "worker_memory_limits=128",
69-
"head_memory=128": "head_memory=48",
70+
"head_memory_requests=128": "head_memory_requests=48",
71+
"head_memory_limits=256": "head_memory_limits=48",
7072
"client = cluster.job_client": "ray_dashboard = cluster.cluster_dashboard_uri()\\n\",\n\t\"header = {\\\"Authorization\\\": \\\"Bearer " + userToken + "\\\"}\\n\",\n\t\"client = RayJobClient(address=ray_dashboard, headers=header, verify=False)\\n",
7173
"--num-devices=8": fmt.Sprintf("--num-devices=%d", numGpus),
7274
"--num-epochs=3": fmt.Sprintf("--num-epochs=%d", 1),

0 commit comments

Comments (0)