Skip to content

Commit eaab287

Browse files
committed
Add 'provider' for graph edges
- Update schema to support a 'provider' key - Update existing YAML with placeholder edge providers - Add 'provider' attr and params to Transform associated classes and methods - Add unit tests
1 parent a306b37 commit eaab287

File tree

13 files changed

+866
-267
lines changed

13 files changed

+866
-267
lines changed

schemas/neuromaps.schema.json

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -69,17 +69,23 @@
6969
"surfaces": {
7070
"type": "object",
7171
"patternProperties": {
72-
"^[0-9]+k?$": {
72+
"^[A-Za-z0-9_]+$": {
7373
"type": "object",
74-
"properties": {
75-
"sphere": {
76-
"$ref": "#/$defs/surfaceHemisphere"
74+
"patternProperties": {
75+
"^[0-9]+k?$": {
76+
"type": "object",
77+
"properties": {
78+
"sphere": {
79+
"$ref": "#/$defs/surfaceHemisphere"
80+
}
81+
},
82+
"required": [
83+
"sphere"
84+
],
85+
"additionalProperties": true
7786
}
7887
},
79-
"required": [
80-
"sphere"
81-
],
82-
"additionalProperties": true
88+
"additionalProperties": false
8389
}
8490
},
8591
"additionalProperties": false
@@ -107,7 +113,13 @@
107113
"description": "Target node."
108114
},
109115
"volumes": {
110-
"$ref": "#/$defs/volumeDict"
116+
"type": "object",
117+
"description": "Transform source.",
118+
"patternProperties": {
119+
"^[A-Za-z0-9_]+$": {
120+
"$ref": "#/$defs/volumeDict"
121+
}
122+
}
111123
}
112124
},
113125
"required": [

src/neuromaps_prime/graph/builder.py

Lines changed: 58 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -196,22 +196,35 @@ def _parse_surface_resources(
196196
Returns:
197197
List of instantiated surface resource objects.
198198
"""
199+
is_transform = cls is SurfaceTransform
199200
prefix = fixed_fields.get("space") or (
200201
f"{fixed_fields['source_space']}_to_{fixed_fields['target_space']}"
201202
)
202-
result = [
203-
cls(
204-
name=f"{prefix}_{density}_{hemi}_{surf_type}",
205-
file_path=self._resolve_path(path),
206-
density=density,
207-
hemisphere=hemi,
208-
resource_type=surf_type,
209-
**fixed_fields,
210-
)
211-
for density, types in surfaces_dict.items()
212-
for surf_type, hemispheres in types.items()
213-
for hemi, path in hemispheres.items()
214-
]
203+
204+
result = []
205+
for outer_key, outer_val in surfaces_dict.items():
206+
if is_transform:
207+
provider = outer_key
208+
density_dict = outer_val
209+
else:
210+
provider = ""
211+
density_dict = {outer_key: outer_val}
212+
213+
for density, types in density_dict.items():
214+
for surf_type, hemispheres in types.items():
215+
for hemi, path in hemispheres.items():
216+
extra = {"provider": provider} if is_transform else {}
217+
result.append(
218+
cls(
219+
name=f"{prefix}_{density}_{hemi}_{surf_type}",
220+
file_path=self._resolve_path(path),
221+
density=density,
222+
hemisphere=hemi,
223+
resource_type=surf_type,
224+
**fixed_fields,
225+
**extra,
226+
)
227+
)
215228
if cls is SurfaceAtlas:
216229
return cast(list[SurfaceAtlas], result)
217230
return cast(list[SurfaceTransform], result)
@@ -222,30 +235,49 @@ def _parse_volume_resources(
222235
fixed_fields: dict[str, Any],
223236
volumes_dict: dict[str, Any],
224237
) -> list[VolumeAtlas] | list[VolumeTransform]:
225-
"""Parse volume resource entries from a nested resolution/type dict.
238+
"""Parse volume resource entries from a nested dict.
239+
240+
Supports both the legacy format (resolution → resource_type) used for
241+
node atlases, and the provider-prefixed format
242+
(provider → resolution → resource_type) used for edge transforms.
226243
227244
Args:
228245
cls: The model class to instantiate (VolumeAtlas or VolumeTransform).
229246
fixed_fields: Fields shared by every entry (e.g. space, description).
230-
volumes_dict: Nested dict keyed by resolution → resource_type.
247+
volumes_dict: Nested dict, either ``{resolution: {type: path}}``
248+
or ``{provider: {resolution: {type: path}}}``.
231249
232250
Returns:
233251
List of instantiated volume resource objects.
234252
"""
253+
is_transform = cls is VolumeTransform
235254
prefix = fixed_fields.get("space") or (
236255
f"{fixed_fields['source_space']}_to_{fixed_fields['target_space']}"
237256
)
238-
result = [
239-
cls(
240-
name=f"{prefix}_{res}_{vol_type}",
241-
file_path=self._resolve_path(path),
242-
resolution=res,
243-
resource_type=vol_type,
244-
**fixed_fields,
245-
)
246-
for res, types in volumes_dict.items()
247-
for vol_type, path in types.items()
248-
]
257+
258+
result: list[Any] = []
259+
for outer_key, outer_val in volumes_dict.items():
260+
if is_transform:
261+
provider = outer_key
262+
resolution_dict = outer_val
263+
else:
264+
provider = ""
265+
resolution_dict = {outer_key: outer_val}
266+
267+
for res, types in resolution_dict.items():
268+
for vol_type, path in types.items():
269+
extra = {"provider": provider} if is_transform else {}
270+
result.append(
271+
cls(
272+
name=f"{prefix}_{res}_{vol_type}",
273+
file_path=self._resolve_path(path),
274+
resolution=res,
275+
resource_type=vol_type,
276+
**fixed_fields, # type: ignore[arg-type]
277+
**extra, # type: ignore[arg-type]
278+
)
279+
)
280+
249281
if cls is VolumeAtlas:
250282
return cast(list[VolumeAtlas], result)
251283
return cast(list[VolumeTransform], result)

src/neuromaps_prime/graph/cache.py

Lines changed: 64 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@
1919

2020
# Key type aliases
2121
SurfaceAtlasKey = tuple[str, str, str, str]
22-
SurfaceTransformKey = tuple[str, str, str, str, str]
22+
SurfaceTransformKey = tuple[str, str, str, str, str, str]
2323
VolumeAtlasKey = tuple[str, str, str]
24-
VolumeTransformKey = tuple[str, str, str, str]
24+
VolumeTransformKey = tuple[str, str, str, str, str]
2525

2626

2727
class GraphCache(BaseModel):
@@ -38,13 +38,13 @@ class GraphCache(BaseModel):
3838
Maps ``(space, density, hemisphere, resource_type)`` to a
3939
:class:`SurfaceAtlas`.
4040
surface_transform:
41-
Maps ``(source, target, density, hemisphere, resource_type)`` to a
41+
Maps ``(source, target, density, hemisphere, resource_type, provider)`` to a
4242
:class:`SurfaceTransform`.
4343
volume_atlas:
4444
Maps ``(space, resolution, resource_type)`` to a
4545
:class:`VolumeAtlas`.
4646
volume_transform:
47-
Maps ``(source, target, resolution, resource_type)`` to a
47+
Maps ``(source, target, resolution, resource_type, provider)`` to a
4848
:class:`VolumeTransform`.
4949
"""
5050

@@ -150,6 +150,7 @@ def add_surface_transform(self, transform: SurfaceTransform) -> None:
150150
transform.density,
151151
transform.hemisphere.lower(),
152152
transform.resource_type,
153+
transform.provider,
153154
)
154155
] = transform
155156

@@ -160,11 +161,30 @@ def get_surface_transform(
160161
density: str,
161162
hemisphere: Literal["left", "right"],
162163
resource_type: str,
164+
provider: str | None = None,
163165
) -> SurfaceTransform | None:
164-
"""Return the matching :class:`SurfaceTransform`, or ``None``."""
165-
return self.surface_transform.get(
166-
(source, target, density, hemisphere.lower(), resource_type)
167-
)
166+
"""Return the matching :class:`SurfaceTransform`, or ``None``.
167+
168+
If *provider* is ``None`` or not found, falls back to the first
169+
registered transform matching the other fields.
170+
"""
171+
if provider is not None:
172+
result = self.surface_transform.get(
173+
(source, target, density, hemisphere.lower(), resource_type, provider)
174+
)
175+
if result is not None:
176+
return result
177+
# Fallback: first match ignoring provider
178+
for (src, tgt, d, h, rt, _), transform in self.surface_transform.items():
179+
if (
180+
src == source
181+
and tgt == target
182+
and d == density
183+
and h == hemisphere.lower()
184+
and rt == resource_type
185+
):
186+
return transform
187+
return None
168188

169189
def get_surface_transforms(
170190
self,
@@ -173,6 +193,7 @@ def get_surface_transforms(
173193
density: str | None = None,
174194
hemisphere: Literal["left", "right"] | None = None,
175195
resource_type: str | None = None,
196+
provider: str | None = None,
176197
) -> list[SurfaceTransform]:
177198
"""Return all surface transforms between two spaces with optional filters.
178199
@@ -182,18 +203,20 @@ def get_surface_transforms(
182203
density: Optional density filter.
183204
hemisphere: Optional hemisphere filter.
184205
resource_type: Optional resource type filter.
206+
provider: Optional provider filter.
185207
186208
Returns:
187209
All matching :class:`SurfaceTransform` entries (may be empty).
188210
"""
189211
return [
190212
transform
191-
for (src, tgt, d, h, rt), transform in self.surface_transform.items()
213+
for (src, tgt, d, h, rt, prov), transform in self.surface_transform.items()
192214
if src == source
193215
and tgt == target
194216
and (density is None or d == density)
195217
and (hemisphere is None or h == hemisphere.lower())
196218
and (resource_type is None or rt == resource_type)
219+
and (provider is None or prov == provider)
197220
]
198221

199222
# ------------------------------------------------------------------ #
@@ -246,21 +269,47 @@ def add_volume_transform(self, transform: VolumeTransform) -> None:
246269
transform.target_space,
247270
transform.resolution,
248271
transform.resource_type,
272+
transform.provider,
249273
)
250274
] = transform
251275

252276
def get_volume_transform(
253-
self, source: str, target: str, resolution: str, resource_type: str
277+
self,
278+
source: str,
279+
target: str,
280+
resolution: str,
281+
resource_type: str,
282+
provider: str | None,
254283
) -> VolumeTransform | None:
255-
"""Return the matching :class:`VolumeTransform`, or ``None``."""
256-
return self.volume_transform.get((source, target, resolution, resource_type))
284+
"""Return the matching :class:`VolumeTransform`, or ``None``.
285+
286+
If *provider* is ``None`` or not found, falls back to the first
287+
registered transform matching the other fields.
288+
"""
289+
if provider is not None:
290+
result = self.volume_transform.get(
291+
(source, target, resolution, resource_type, provider)
292+
)
293+
if result is not None:
294+
return result
295+
# Fallback: first match ignoring provider
296+
for (src, tgt, res, rt, _), transform in self.volume_transform.items():
297+
if (
298+
src == source
299+
and tgt == target
300+
and res == resolution
301+
and rt == resource_type
302+
):
303+
return transform
304+
return None
257305

258306
def get_volume_transforms(
259307
self,
260308
source: str,
261309
target: str,
262310
resolution: str | None = None,
263311
resource_type: str | None = None,
312+
provider: str | None = None,
264313
) -> list[VolumeTransform]:
265314
"""Return all volume transforms between two spaces with optional filters.
266315
@@ -269,17 +318,19 @@ def get_volume_transforms(
269318
target: Target brain template space name.
270319
resolution: Optional resolution filter.
271320
resource_type: Optional resource type filter.
321+
provider: Optional provider filter.
272322
273323
Returns:
274324
All matching :class:`VolumeTransform` entries (may be empty).
275325
"""
276326
return [
277327
transform
278-
for (src, tgt, res, rt), transform in self.volume_transform.items()
328+
for (src, tgt, res, rt, prov), transform in self.volume_transform.items()
279329
if src == source
280330
and tgt == target
281331
and (resolution is None or res == resolution)
282332
and (resource_type is None or rt == resource_type)
333+
and (provider is None or prov == provider)
283334
]
284335

285336
# ------------------------------------------------------------------ #

0 commit comments

Comments
 (0)