Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 21 additions & 9 deletions schemas/neuromaps.schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -69,17 +69,23 @@
"surfaces": {
"type": "object",
"patternProperties": {
"^[0-9]+k?$": {
"^[A-Za-z0-9_]+$": {
"type": "object",
"properties": {
"sphere": {
"$ref": "#/$defs/surfaceHemisphere"
"patternProperties": {
"^[0-9]+k?$": {
"type": "object",
"properties": {
"sphere": {
"$ref": "#/$defs/surfaceHemisphere"
}
},
"required": [
"sphere"
],
"additionalProperties": true
}
},
"required": [
"sphere"
],
"additionalProperties": true
"additionalProperties": false
}
},
"additionalProperties": false
Expand Down Expand Up @@ -107,7 +113,13 @@
"description": "Target node."
},
"volumes": {
"$ref": "#/$defs/volumeDict"
"type": "object",
"description": "Transform source.",
"patternProperties": {
"^[A-Za-z0-9_]+$": {
"$ref": "#/$defs/volumeDict"
}
}
}
},
"required": [
Expand Down
84 changes: 58 additions & 26 deletions src/neuromaps_prime/graph/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,22 +196,35 @@ def _parse_surface_resources(
Returns:
List of instantiated surface resource objects.
"""
is_transform = cls is SurfaceTransform
prefix = fixed_fields.get("space") or (
f"{fixed_fields['source_space']}_to_{fixed_fields['target_space']}"
)
result = [
cls(
name=f"{prefix}_{density}_{hemi}_{surf_type}",
file_path=self._resolve_path(path),
density=density,
hemisphere=hemi,
resource_type=surf_type,
**fixed_fields,
)
for density, types in surfaces_dict.items()
for surf_type, hemispheres in types.items()
for hemi, path in hemispheres.items()
]

result = []
for outer_key, outer_val in surfaces_dict.items():
if is_transform:
provider = outer_key
density_dict = outer_val
else:
provider = ""
density_dict = {outer_key: outer_val}

for density, types in density_dict.items():
for surf_type, hemispheres in types.items():
for hemi, path in hemispheres.items():
extra = {"provider": provider} if is_transform else {}
result.append(
cls(
name=f"{prefix}_{density}_{hemi}_{surf_type}",
file_path=self._resolve_path(path),
density=density,
hemisphere=hemi,
resource_type=surf_type,
**fixed_fields, # type: ignore[arg-type]
**extra, # type: ignore[arg-type]
)
)
if cls is SurfaceAtlas:
return cast(list[SurfaceAtlas], result)
return cast(list[SurfaceTransform], result)
Expand All @@ -222,30 +235,49 @@ def _parse_volume_resources(
fixed_fields: dict[str, Any],
volumes_dict: dict[str, Any],
) -> list[VolumeAtlas] | list[VolumeTransform]:
"""Parse volume resource entries from a nested resolution/type dict.
"""Parse volume resource entries from a nested dict.

Supports both the legacy format (resolution → resource_type) used for
node atlases, and the provider-prefixed format
(provider → resolution → resource_type) used for edge transforms.

Args:
cls: The model class to instantiate (VolumeAtlas or VolumeTransform).
fixed_fields: Fields shared by every entry (e.g. space, description).
volumes_dict: Nested dict keyed by resolution → resource_type.
volumes_dict: Nested dict, either ``{resolution: {type: path}}``
or ``{provider: {resolution: {type: path}}}``.

Returns:
List of instantiated volume resource objects.
"""
is_transform = cls is VolumeTransform
prefix = fixed_fields.get("space") or (
f"{fixed_fields['source_space']}_to_{fixed_fields['target_space']}"
)
result = [
cls(
name=f"{prefix}_{res}_{vol_type}",
file_path=self._resolve_path(path),
resolution=res,
resource_type=vol_type,
**fixed_fields,
)
for res, types in volumes_dict.items()
for vol_type, path in types.items()
]

result: list[Any] = []
for outer_key, outer_val in volumes_dict.items():
if is_transform:
provider = outer_key
resolution_dict = outer_val
else:
provider = ""
resolution_dict = {outer_key: outer_val}

for res, types in resolution_dict.items():
for vol_type, path in types.items():
extra = {"provider": provider} if is_transform else {}
result.append(
cls(
name=f"{prefix}_{res}_{vol_type}",
file_path=self._resolve_path(path),
resolution=res,
resource_type=vol_type,
**fixed_fields, # type: ignore[arg-type]
**extra, # type: ignore[arg-type]
)
)

if cls is VolumeAtlas:
return cast(list[VolumeAtlas], result)
return cast(list[VolumeTransform], result)
77 changes: 64 additions & 13 deletions src/neuromaps_prime/graph/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@

# Key type aliases
SurfaceAtlasKey = tuple[str, str, str, str]
SurfaceTransformKey = tuple[str, str, str, str, str]
SurfaceTransformKey = tuple[str, str, str, str, str, str]
VolumeAtlasKey = tuple[str, str, str]
VolumeTransformKey = tuple[str, str, str, str]
VolumeTransformKey = tuple[str, str, str, str, str]


class GraphCache(BaseModel):
Expand All @@ -38,13 +38,13 @@ class GraphCache(BaseModel):
Maps ``(space, density, hemisphere, resource_type)`` to a
:class:`SurfaceAtlas`.
surface_transform:
Maps ``(source, target, density, hemisphere, resource_type)`` to a
Maps ``(source, target, density, hemisphere, resource_type, provider)`` to a
:class:`SurfaceTransform`.
volume_atlas:
Maps ``(space, resolution, resource_type)`` to a
:class:`VolumeAtlas`.
volume_transform:
Maps ``(source, target, resolution, resource_type)`` to a
Maps ``(source, target, resolution, resource_type, provider)`` to a
:class:`VolumeTransform`.
"""

Expand Down Expand Up @@ -150,6 +150,7 @@ def add_surface_transform(self, transform: SurfaceTransform) -> None:
transform.density,
transform.hemisphere.lower(),
transform.resource_type,
transform.provider,
)
] = transform

Expand All @@ -160,11 +161,30 @@ def get_surface_transform(
density: str,
hemisphere: Literal["left", "right"],
resource_type: str,
provider: str | None = None,
) -> SurfaceTransform | None:
"""Return the matching :class:`SurfaceTransform`, or ``None``."""
return self.surface_transform.get(
(source, target, density, hemisphere.lower(), resource_type)
)
"""Return the matching :class:`SurfaceTransform`, or ``None``.

If *provider* is ``None`` or not found, falls back to the first
registered transform matching the other fields.
"""
if provider is not None:
result = self.surface_transform.get(
(source, target, density, hemisphere.lower(), resource_type, provider)
)
if result is not None:
return result
# Fallback: first match ignoring provider
for (src, tgt, d, h, rt, _), transform in self.surface_transform.items():
if (
src == source
and tgt == target
and d == density
and h == hemisphere.lower()
and rt == resource_type
):
return transform
return None

def get_surface_transforms(
self,
Expand All @@ -173,6 +193,7 @@ def get_surface_transforms(
density: str | None = None,
hemisphere: Literal["left", "right"] | None = None,
resource_type: str | None = None,
provider: str | None = None,
) -> list[SurfaceTransform]:
"""Return all surface transforms between two spaces with optional filters.

Expand All @@ -182,18 +203,20 @@ def get_surface_transforms(
density: Optional density filter.
hemisphere: Optional hemisphere filter.
resource_type: Optional resource type filter.
provider: Optional provider filter.

Returns:
All matching :class:`SurfaceTransform` entries (may be empty).
"""
return [
transform
for (src, tgt, d, h, rt), transform in self.surface_transform.items()
for (src, tgt, d, h, rt, prov), transform in self.surface_transform.items()
if src == source
and tgt == target
and (density is None or d == density)
and (hemisphere is None or h == hemisphere.lower())
and (resource_type is None or rt == resource_type)
and (provider is None or prov == provider)
]

# ------------------------------------------------------------------ #
Expand Down Expand Up @@ -246,21 +269,47 @@ def add_volume_transform(self, transform: VolumeTransform) -> None:
transform.target_space,
transform.resolution,
transform.resource_type,
transform.provider,
)
] = transform

def get_volume_transform(
self, source: str, target: str, resolution: str, resource_type: str
self,
source: str,
target: str,
resolution: str,
resource_type: str,
provider: str | None,
) -> VolumeTransform | None:
"""Return the matching :class:`VolumeTransform`, or ``None``."""
return self.volume_transform.get((source, target, resolution, resource_type))
"""Return the matching :class:`VolumeTransform`, or ``None``.

If *provider* is ``None`` or not found, falls back to the first
registered transform matching the other fields.
"""
if provider is not None:
result = self.volume_transform.get(
(source, target, resolution, resource_type, provider)
)
if result is not None:
return result
# Fallback: first match ignoring provider
for (src, tgt, res, rt, _), transform in self.volume_transform.items():
if (
src == source
and tgt == target
and res == resolution
and rt == resource_type
):
return transform
return None

def get_volume_transforms(
self,
source: str,
target: str,
resolution: str | None = None,
resource_type: str | None = None,
provider: str | None = None,
) -> list[VolumeTransform]:
"""Return all volume transforms between two spaces with optional filters.

Expand All @@ -269,17 +318,19 @@ def get_volume_transforms(
target: Target brain template space name.
resolution: Optional resolution filter.
resource_type: Optional resource type filter.
provider: Optional provider filter.

Returns:
All matching :class:`VolumeTransform` entries (may be empty).
"""
return [
transform
for (src, tgt, res, rt), transform in self.volume_transform.items()
for (src, tgt, res, rt, prov), transform in self.volume_transform.items()
if src == source
and tgt == target
and (resolution is None or res == resolution)
and (resource_type is None or rt == resource_type)
and (provider is None or prov == provider)
]

# ------------------------------------------------------------------ #
Expand Down
Loading