|
20 | 20 | InferenceServer, |
21 | 21 | InferenceTransport, |
22 | 22 | MPTransport, |
| 23 | + RayTransport, |
23 | 24 | ThreadingTransport, |
24 | 25 | ) |
25 | 26 |
|
| 27 | +_has_ray = True |
| 28 | +try: |
| 29 | + import ray |
| 30 | +except ImportError: |
| 31 | + _has_ray = False |
| 32 | + |
26 | 33 |
|
27 | 34 | # ============================================================================= |
28 | 35 | # Helpers |
@@ -398,3 +405,88 @@ def bad_model(td): |
398 | 405 | td = TensorDict({"observation": torch.randn(4)}) |
399 | 406 | with pytest.raises(ValueError, match="mp model error"): |
400 | 407 | client(td) |
| 408 | + |
| 409 | + |
| 410 | +# ============================================================================= |
| 411 | +# Tests: RayTransport (Commit 4) |
| 412 | +# ============================================================================= |
| 413 | + |
| 414 | + |
@pytest.mark.skipif(not _has_ray, reason="ray not installed")
class TestRayTransport:
    """Integration tests for RayTransport against a live InferenceServer."""

    @classmethod
    def setup_class(cls):
        # Share one Ray runtime across every test in this class.
        if not ray.is_initialized():
            ray.init(num_cpus=4, ignore_reinit_error=True)

    def test_single_request(self):
        """A lone client round-trips a single request through the server."""
        transport = RayTransport()
        requester = transport.client()
        policy = _make_policy()
        with InferenceServer(policy, transport, max_batch_size=4):
            out = requester(TensorDict({"observation": torch.randn(4)}))
            assert "action" in out.keys()
            assert out["action"].shape == (2,)

    def test_concurrent_clients(self):
        """Several threads, each holding its own client, submit concurrently
        (simulating Ray actors); every request must get a well-formed reply."""
        transport = RayTransport()
        policy = _make_policy()
        num_workers = 4
        per_worker = 20

        handles = [transport.client() for _ in range(num_workers)]
        collected: list[list[TensorDictBase]] = [[] for _ in range(num_workers)]

        def worker(idx):
            # Each thread drives exactly one client end-to-end.
            for _ in range(per_worker):
                out = handles[idx](TensorDict({"observation": torch.randn(4)}))
                collected[idx].append(out)

        with InferenceServer(policy, transport, max_batch_size=8):
            with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as pool:
                futures = [pool.submit(worker, i) for i in range(num_workers)]
                concurrent.futures.wait(futures)
                # Re-raise any exception that escaped a worker thread.
                for fut in futures:
                    fut.result()

        for bucket in collected:
            assert len(bucket) == per_worker
            for out in bucket:
                assert "action" in out.keys()
                assert out["action"].shape == (2,)

    def test_ray_remote_actor(self):
        """The client survives serialization into a Ray remote task and can
        query the server from there."""
        transport = RayTransport()
        handle = transport.client()
        policy = _make_policy()

        @ray.remote
        def remote_actor_fn(client, n_requests):
            shapes = []
            for _ in range(n_requests):
                reply = client(TensorDict({"observation": torch.randn(4)}))
                shapes.append(reply["action"].shape)
            return shapes

        with InferenceServer(policy, transport, max_batch_size=8):
            shapes = ray.get(remote_actor_fn.remote(handle, 5), timeout=30.0)
            assert len(shapes) == 5
            for shape in shapes:
                assert shape == (2,)

    def test_ray_exception_propagates(self):
        """A server-side model failure surfaces to the client as the
        original exception type and message."""

        def failing_model(td):
            raise ValueError("ray model error")

        transport = RayTransport()
        handle = transport.client()
        with InferenceServer(failing_model, transport, max_batch_size=4):
            request = TensorDict({"observation": torch.randn(4)})
            with pytest.raises(ValueError, match="ray model error"):
                handle(request)
0 commit comments