Skip to content

Commit a07e261

Browse files
authored
Merge pull request #63 from atomscale-ai/enhancement/project_id_to_stream
Add `project_id` support to RHEED streamer
2 parents 566fcf9 + 25390d5 commit a07e261

File tree

5 files changed

+512
-10
lines changed

5 files changed

+512
-10
lines changed

src/atomscale/streaming/rheed_stream.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ class RHEEDStreamer:
2525
chunk_size: int,
2626
stream_name: str | None = None,
2727
physical_sample: str | None = None,
28+
project_id: str | None = None,
2829
) -> str: ...
2930
def run(
3031
self,

src/atomscale/streaming/src/initialize.rs

Lines changed: 76 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use anyhow::{Context, Result};
22
use reqwest::Client;
33
use serde::{Deserialize, Serialize};
4+
use serde_json::Value;
45

56
#[derive(Serialize, Debug)]
67
#[serde(rename_all = "snake_case")] // Ensures JSON fields are snake_case (e.g., data_id)
@@ -9,6 +10,20 @@ pub struct RHEEDStreamSettings {
910
pub rotational_period: f64,
1011
pub rotations_per_min: f64,
1112
pub fps_capture_rate: f64,
13+
#[serde(skip_serializing_if = "Option::is_none")]
14+
pub project_id: Option<String>,
15+
}
16+
17+
// Project-related structs for GET /projects/ response
18+
#[derive(Deserialize, Debug)]
19+
struct DataConfiguration {
20+
api_configuration: Option<Value>,
21+
}
22+
23+
#[derive(Deserialize, Debug)]
24+
struct ProjectSummary {
25+
id: String,
26+
configuration: Option<DataConfiguration>,
1227
}
1328

1429
/// POST request to initialize a RHEED stream
@@ -58,10 +73,10 @@ pub async fn ensure_physical_sample_link(
5873
api_key: &str,
5974
data_id: &str,
6075
sample_name: &str,
61-
) -> Result<()> {
76+
) -> Result<String> {
6277
let sample_name = sample_name.trim();
6378
if sample_name.is_empty() {
64-
return Ok(());
79+
anyhow::bail!("sample_name cannot be empty");
6580
}
6681

6782
let list_url = format!("{base_endpoint}/physical_samples/");
@@ -102,7 +117,7 @@ pub async fn ensure_physical_sample_link(
102117
let link_url = format!("{base_endpoint}/data_entries/physical_sample");
103118
let link_body = LinkPhysicalSampleRequest {
104119
data_ids: vec![data_id.to_string()],
105-
physical_sample_id: sample_id,
120+
physical_sample_id: sample_id.clone(),
106121
};
107122

108123
client
@@ -115,5 +130,63 @@ pub async fn ensure_physical_sample_link(
115130
.error_for_status()
116131
.context("physical sample link returned error status")?;
117132

133+
Ok(sample_id)
134+
}
135+
136+
/// Updates the project's tracking_physical_sample_id in its configuration.
137+
/// Fetches current configuration, updates the tracking sample, and POSTs back.
138+
pub async fn update_project_tracking_sample(
139+
client: &Client,
140+
base_endpoint: &str,
141+
api_key: &str,
142+
project_id: &str,
143+
physical_sample_id: &str,
144+
) -> Result<()> {
145+
// GET /projects/ to find the project and its current configuration
146+
let projects_url = format!("{base_endpoint}/projects/");
147+
let projects: Vec<ProjectSummary> = client
148+
.get(&projects_url)
149+
.header("X-API-KEY", api_key)
150+
.send()
151+
.await
152+
.context("failed to request projects")?
153+
.error_for_status()
154+
.context("projects list returned error status")?
155+
.json()
156+
.await
157+
.context("failed to deserialize projects list")?;
158+
159+
// Find the project by ID
160+
let project = projects
161+
.into_iter()
162+
.find(|p| p.id == project_id)
163+
.ok_or_else(|| anyhow::anyhow!("project with id {} not found", project_id))?;
164+
165+
// Build updated configuration, preserving existing fields
166+
let mut config = match project.configuration {
167+
Some(data_config) => data_config.api_configuration.unwrap_or_else(|| Value::Object(Default::default())),
168+
None => Value::Object(Default::default()),
169+
};
170+
171+
// Update tracking_physical_sample_id in the configuration
172+
if let Value::Object(ref mut map) = config {
173+
map.insert(
174+
"tracking_physical_sample_id".to_string(),
175+
Value::String(physical_sample_id.to_string()),
176+
);
177+
}
178+
179+
// POST /projects/{project_id}/configuration with the updated config
180+
let config_url = format!("{base_endpoint}/projects/{project_id}/configuration");
181+
client
182+
.post(&config_url)
183+
.header("X-API-KEY", api_key)
184+
.json(&config)
185+
.send()
186+
.await
187+
.context("failed to update project configuration")?
188+
.error_for_status()
189+
.context("project configuration update returned error status")?;
190+
118191
Ok(())
119192
}

src/atomscale/streaming/src/lib.rs

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,10 @@ mod utils;
1313
use utils::{generic_post, init_tracing_once};
1414

1515
mod initialize;
16-
use initialize::{ensure_physical_sample_link, post_for_initialization, RHEEDStreamSettings};
16+
use initialize::{
17+
ensure_physical_sample_link, post_for_initialization, update_project_tracking_sample,
18+
RHEEDStreamSettings,
19+
};
1720

1821
mod upload;
1922
use upload::{
@@ -109,7 +112,7 @@ impl RHEEDStreamer {
109112
}
110113

111114
////Initialize stream
112-
/// initialize(self, stream_name: Optional[str] = None, fps: float, rotations_per_min: float, chunk_size: int, physical_sample: Optional[str] = None) -> str
115+
/// initialize(self, fps: float, rotations_per_min: float, chunk_size: int, stream_name: Optional[str] = None, physical_sample: Optional[str] = None, project_id: Optional[str] = None) -> str
113116
///
114117
/// Creates a new **remote data item** for this stream and returns its `data_id`.
115118
/// Also captures runtime configuration used for subsequent chunk uploads.
@@ -121,21 +124,24 @@ impl RHEEDStreamer {
121124
/// After streaming via `run(...)` or `push(...)`, call `finalize(data_id)` to mark the stream as complete.
122125
///
123126
/// Args:
124-
/// stream_name (Optional[str]): Human-readable name shown in the platform. If `None` or an empty string,
125-
/// a default like `"RHEED Stream @ 1:23PM"` is used.
126127
/// fps (float): Capture rate in frames per second.
127128
/// rotations_per_min (float): Wafer/crystal rotations per minute; use `0.0` for stationary operation.
128129
/// chunk_size (int): The **intended** number of frames per chunk you will send with `run(...)` or `push(...)`.
130+
/// stream_name (Optional[str]): Human-readable name shown in the platform. If `None` or an empty string,
131+
/// a default like `"RHEED Stream @ 1:23PM"` is used.
129132
/// physical_sample (Optional[str]): Name of a physical sample to associate with the data item; matched case-insensitively or created if missing.
133+
/// project_id (Optional[str]): UUID of a project to associate with the stream. When provided along with
134+
/// `physical_sample`, the project's `tracking_physical_sample_id` configuration is automatically updated
135+
/// to link the physical sample to the project for growth monitoring.
130136
///
131137
/// Returns:
132138
/// str: The created `data_id` for this stream.
133139
///
134140
/// Raises:
135141
/// RuntimeError: If the initialization POST fails.
136-
#[pyo3(signature = (fps, rotations_per_min, chunk_size, stream_name=None, physical_sample=None))]
142+
#[pyo3(signature = (fps, rotations_per_min, chunk_size, stream_name=None, physical_sample=None, project_id=None))]
137143
#[pyo3(
138-
text_signature = "(fps, rotations_per_min, chunk_size, stream_name=None, physical_sample=None)"
144+
text_signature = "(fps, rotations_per_min, chunk_size, stream_name=None, physical_sample=None, project_id=None)"
139145
)]
140146
fn initialize(
141147
&mut self,
@@ -144,6 +150,7 @@ impl RHEEDStreamer {
144150
chunk_size: usize,
145151
stream_name: Option<String>,
146152
physical_sample: Option<String>,
153+
project_id: Option<String>,
147154
) -> PyResult<String> {
148155
// Guard: chunk_size must be >= ceil(2 * fps)
149156
let min_chunk = (2.0 * fps).ceil() as usize;
@@ -165,6 +172,10 @@ impl RHEEDStreamer {
165172
.map(|s| s.trim().to_string())
166173
.filter(|s| !s.is_empty());
167174

175+
let project_id = project_id
176+
.map(|s| s.trim().to_string())
177+
.filter(|s| !s.is_empty());
178+
168179
let fpr = (fps * 60.0) / rotations_per_min;
169180

170181
#[allow(clippy::redundant_field_names)]
@@ -173,6 +184,7 @@ impl RHEEDStreamer {
173184
rotational_period: fpr,
174185
rotations_per_min,
175186
fps_capture_rate: fps,
187+
project_id,
176188
};
177189

178190
let base_endpoint = self.endpoint.clone();
@@ -192,9 +204,24 @@ impl RHEEDStreamer {
192204
&data_id,
193205
&sample_name,
194206
);
195-
self.rt
207+
let sample_id = self
208+
.rt
196209
.block_on(physical_sample_fut)
197210
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
211+
212+
// If project_id was provided, update the project's tracking_physical_sample_id
213+
if let Some(ref proj_id) = settings.project_id {
214+
let update_project_fut = update_project_tracking_sample(
215+
&self.client,
216+
&base_endpoint,
217+
&self.api_key,
218+
proj_id,
219+
&sample_id,
220+
);
221+
self.rt
222+
.block_on(update_project_fut)
223+
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
224+
}
198225
}
199226

200227
self.fps = Some(fps);

tests/_mock_http_server.py

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
"""Subprocess-based mock HTTP server for testing Rust HTTP clients.
2+
3+
This module is designed to be run as a subprocess to provide true process
4+
isolation when testing Rust HTTP clients (like reqwest) from Python tests.
5+
6+
Usage:
7+
python -m tests._mock_http_server <port> <json_response>
8+
python -m tests._mock_http_server <port> <json_routes_dict>
9+
10+
The server:
11+
- Listens on 127.0.0.1:<port>
12+
- Prints "READY:<port>" to stdout when ready
13+
- For simple mode: handles one POST request, prints "BODY:<json>" to stdout
14+
- For routes mode: handles multiple requests based on path matching
15+
- Responds with the provided JSON response
16+
"""
17+
import json
18+
import sys
19+
from http.server import BaseHTTPRequestHandler, HTTPServer
20+
21+
22+
class CaptureHandler(BaseHTTPRequestHandler):
23+
"""HTTP handler that captures request body and returns configured response."""
24+
25+
def _handle_request(self, method: str):
26+
"""Common handler for all HTTP methods."""
27+
# Read request body if present
28+
length = int(self.headers.get("Content-Length", 0))
29+
body = self.rfile.read(length) if length > 0 else b""
30+
31+
path = self.path
32+
routes = getattr(self.server, "routes", None)
33+
34+
if routes:
35+
# Routes mode: find matching route
36+
response_data = None
37+
for route_path, route_response in routes.items():
38+
if path.startswith(route_path):
39+
response_data = route_response
40+
break
41+
42+
if response_data is None:
43+
# Default response for unmatched routes
44+
response_data = '""'
45+
46+
# Print request info for debugging
47+
print(f"REQUEST:{method}:{path}:{body.decode() if body else ''}", flush=True)
48+
else:
49+
# Simple mode: single response for all requests
50+
response_data = self.server.response_data
51+
print(f"BODY:{body.decode()}", flush=True)
52+
53+
# Send response
54+
self.send_response(200)
55+
self.send_header("Content-Type", "application/json")
56+
self.send_header("Content-Length", len(response_data))
57+
self.end_headers()
58+
self.wfile.write(response_data.encode())
59+
60+
def do_GET(self):
61+
self._handle_request("GET")
62+
63+
def do_POST(self):
64+
self._handle_request("POST")
65+
66+
def do_PUT(self):
67+
self._handle_request("PUT")
68+
69+
def log_message(self, format, *args):
70+
"""Suppress default logging."""
71+
pass
72+
73+
74+
class MultiRequestServer(HTTPServer):
75+
"""HTTP server that can handle multiple requests."""
76+
77+
def __init__(self, *args, max_requests: int = 1, **kwargs):
78+
super().__init__(*args, **kwargs)
79+
self.max_requests = max_requests
80+
self.request_count = 0
81+
82+
def handle_requests(self):
83+
"""Handle up to max_requests requests."""
84+
while self.request_count < self.max_requests:
85+
self.handle_request()
86+
self.request_count += 1
87+
88+
89+
def run_server(port: int, response_data: str) -> None:
90+
"""Run the mock HTTP server."""
91+
# Try to parse as routes dict
92+
try:
93+
routes = json.loads(response_data)
94+
if isinstance(routes, dict) and routes.get("__routes__"):
95+
# Routes mode
96+
del routes["__routes__"]
97+
max_requests = routes.pop("__max_requests__", 10)
98+
server = MultiRequestServer(
99+
("127.0.0.1", port), CaptureHandler, max_requests=max_requests
100+
)
101+
server.routes = routes
102+
print(f"READY:{port}", flush=True)
103+
server.handle_requests()
104+
return
105+
except (json.JSONDecodeError, TypeError):
106+
pass
107+
108+
# Simple mode: single request with static response
109+
server = HTTPServer(("127.0.0.1", port), CaptureHandler)
110+
server.response_data = response_data
111+
print(f"READY:{port}", flush=True)
112+
server.handle_request()
113+
114+
115+
if __name__ == "__main__":
116+
if len(sys.argv) != 3:
117+
print(f"Usage: {sys.argv[0]} <port> <json_response>", file=sys.stderr)
118+
sys.exit(1)
119+
120+
port = int(sys.argv[1])
121+
response_data = sys.argv[2]
122+
run_server(port, response_data)

0 commit comments

Comments
 (0)