@@ -35,7 +35,7 @@ def __init__(
3535 self ,
3636 base_dir : os .PathLike ,
3737 n_workers : Optional [int ] = 1 ,
38- buffer_ratio : Optional [float ] = 1.0 ,
38+ scale_factor : Optional [float ] = 1.0 ,
3939 sample_type : str = None ,
4040 weights : pd .DataFrame = None ,
4141 ):
@@ -52,8 +52,8 @@ def __init__(
5252 The sample type of the raw data, e.g., 'xenium' or 'merscope'.
5353 weights : Optional[pd.DataFrame], default None
5454 DataFrame containing weights for transcript embedding.
55- buffer_ratio : Optional[float], default None
56- The buffer ratio to be used for expanding the boundary extents
55+ scale_factor : Optional[float], default None
56+ The scale factor to be used for expanding the boundary extents
5757 during spatial queries. If not provided, the default from settings
5858 will be used.
5959
@@ -71,15 +71,15 @@ def __init__(
7171 boundaries_fn = self .settings .boundaries .filename
7272 self ._boundaries_filepath = self ._base_dir / boundaries_fn
7373 self .n_workers = n_workers
74- self .settings .boundaries .buffer_ratio = 1
74+ self .settings .boundaries .scale_factor = 1
7575 nuclear_column = getattr (self .settings .transcripts , "nuclear_column" , None )
76- if nuclear_column is None or self .settings .boundaries .buffer_ratio != 1.0 :
76+ if nuclear_column is None or self .settings .boundaries .scale_factor != 1.0 :
7777 print (
7878 "Boundary-transcript overlap information has not been pre-computed. It will be calculated during tile generation."
7979 )
80- # Set buffer ratio if provided
81- if buffer_ratio != 1.0 :
82- self .settings .boundaries .buffer_ratio = buffer_ratio
80+ # Set scale factor if provided
81+ if scale_factor != 1.0 :
82+ self .settings .boundaries .scale_factor = scale_factor
8383
8484 # Ensure transcript IDs exist
8585 utils .ensure_transcript_ids (
@@ -1164,13 +1164,12 @@ def get_boundary_props(
11641164 of the code.
11651165 """
11661166 # Get polygons from coordinates
1167- polygons = utils .get_polygons_from_xy (
1168- self .boundaries ,
1169- x = self .settings .boundaries .x ,
1170- y = self .settings .boundaries .y ,
1171- label = self .settings .boundaries .label ,
1172- buffer_ratio = self .settings .boundaries .buffer_ratio ,
1173- )
1167+ # Use getattr to check for the geometry column
1168+ geometry_column = getattr (self .settings .boundaries , 'geometry' , None )
1169+ if geometry_column and geometry_column in self .boundaries .columns :
1170+ polygons = self .boundaries [geometry_column ]
1171+ else :
1172+ polygons = self .boundaries ['geometry' ] # Assign None if the geometry column does not exist
11741173 # Geometric properties of polygons
11751174 props = self .get_polygon_props (polygons )
11761175 props = torch .as_tensor (props .values ).float ()
@@ -1230,13 +1229,22 @@ def to_pyg_dataset(
12301229 pyg_data ["tx" , "neighbors" , "tx" ].edge_index = nbrs_edge_idx
12311230
12321231 # Set up Boundary nodes
1233- polygons = utils .get_polygons_from_xy (
1234- self .boundaries ,
1235- self .settings .boundaries .x ,
1236- self .settings .boundaries .y ,
1237- self .settings .boundaries .label ,
1238- self .settings .boundaries .buffer_ratio ,
1239- )
1232+ # Check if boundaries have geometries
1233+ geometry_column = getattr (self .settings .boundaries , 'geometry' , None )
1234+ if geometry_column and geometry_column in self .boundaries .columns :
1235+ polygons = gpd .GeoSeries (self .boundaries [geometry_column ], index = self .boundaries .index )
1236+ else :
1237+ # Fallback: compute polygons
1238+ polygons = utils .get_polygons_from_xy (
1239+ self .boundaries ,
1240+ x = self .settings .boundaries .x ,
1241+ y = self .settings .boundaries .y ,
1242+ label = self .settings .boundaries .label ,
1243+ scale_factor = self .settings .boundaries .scale_factor ,
1244+ )
1245+
1246+ # Ensure self.boundaries is a GeoDataFrame with correct geometry
1247+ self .boundaries = gpd .GeoDataFrame (self .boundaries .copy (), geometry = polygons )
12401248 centroids = polygons .centroid .get_coordinates ()
12411249 pyg_data ["bd" ].id = polygons .index .to_numpy ()
12421250 pyg_data ["bd" ].pos = torch .tensor (centroids .values , dtype = torch .float32 )
@@ -1273,7 +1281,7 @@ def to_pyg_dataset(
12731281 nuclear_column = getattr (self .settings .transcripts , "nuclear_column" , None )
12741282 nuclear_value = getattr (self .settings .transcripts , "nuclear_value" , None )
12751283
1276- if nuclear_column is None or self .settings .boundaries .buffer_ratio != 1.0 :
1284+ if nuclear_column is None or self .settings .boundaries .scale_factor != 1.0 :
12771285 is_nuclear = utils .compute_nuclear_transcripts (
12781286 polygons = polygons ,
12791287 transcripts = self .transcripts ,
0 commit comments