Skip to content

Commit bab5dbb

Browse files
authored
Merge pull request #58 from Leengit/mask_granularity
BUG: Better alignment accuracy between overlapping tiles and mask. ENH: Allow specification of mask_threshold. ENH: Updated tensorflow_stream.ipynb jupyter lab. PERF: TilesByGridAndMask speed up via numpy array operations.
2 parents d77411b + 990ddbf commit bab5dbb

File tree

2 files changed

+167
-66
lines changed

2 files changed

+167
-66
lines changed

example/tensorflow_stream.ipynb

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,16 @@
3131
"source": [
3232
"!apt update\n",
3333
"!apt install -y python3-openslide openslide-tools\n",
34+
"\n",
3435
"!pip install histomics_stream 'large_image[openslide,ometiff,openjpeg,bioformats]' pooch --find-links https://girder.github.io/large_image_wheels\n",
35-
"!pip install /tf/notebooks/histomics_detect\n",
3636
"\n",
3737
"import sys\n",
38+
"!pip install /tf/notebooks/histomics_detect\n",
3839
"sys.path.append(\"/tf/notebooks/histomics_detect/\")\n",
3940
"\n",
40-
"print(\"\\nNOTE!: On Google Colab you may need to choose 'Runtime->Restart runtime' for these updates to take effect.\")"
41+
"print(\n",
42+
" \"\\nNOTE!: On Google Colab you may need to choose 'Runtime->Restart runtime' for these updates to take effect.\"\n",
43+
")"
4144
]
4245
},
4346
{
@@ -198,7 +201,9 @@
198201
"# Demonstrate TilesByList\n",
199202
"my_study_tiles_by_list = copy.deepcopy(my_study0)\n",
200203
"tiles_by_list = hs.configure.TilesByList(\n",
201-
" my_study_tiles_by_list, randomly_select=5, tiles_dictionary=my_study_tiles_by_grid[\"slides\"][\"Slide_0\"][\"tiles\"]\n",
204+
" my_study_tiles_by_list,\n",
205+
" randomly_select=5,\n",
206+
" tiles_dictionary=my_study_tiles_by_grid[\"slides\"][\"Slide_0\"][\"tiles\"],\n",
202207
")\n",
203208
"# We could apply this to a subset of the slides, but we will apply it to all slides in\n",
204209
"# this example.\n",
@@ -248,6 +253,7 @@
248253
" number_pixel_overlap_rows_for_tile=0,\n",
249254
" number_pixel_overlap_columns_for_tile=0,\n",
250255
" mask_filename=mask_path,\n",
256+
" mask_threshold=0.5,\n",
251257
" randomly_select=100,\n",
252258
")\n",
253259
"for slide in my_study_of_tiles[\"slides\"].values():\n",
@@ -294,7 +300,9 @@
294300
"# restore keras model\n",
295301
"from histomics_detect.models import FasterRCNN\n",
296302
"\n",
297-
"model = tf.keras.models.load_model(model_path, custom_objects={\"FasterRCNN\": FasterRCNN})\n",
303+
"model = tf.keras.models.load_model(\n",
304+
" model_path, custom_objects={\"FasterRCNN\": FasterRCNN}\n",
305+
")\n",
298306
"\n",
299307
"# Each element of the `tiles` tensorflow Dataset is a (rgb_image_data, dictionary_of_annotation) pair.\n",
300308
"# Wrap the unwrapped_model so that it knows to use the image.\n",
@@ -336,7 +344,9 @@
336344
"end_time = time.time()\n",
337345
"number_of_inputs = len([0 for tile in tiles])\n",
338346
"number_of_predictions = predictions[0].shape[0]\n",
339-
"print(f\"Made {number_of_predictions} predictions for {number_of_inputs} tiles in {end_time - start_time} s.\")\n",
347+
"print(\n",
348+
" f\"Made {number_of_predictions} predictions for {number_of_inputs} tiles in {end_time - start_time} s.\"\n",
349+
")\n",
340350
"print(f\"Average of {(end_time - start_time) / number_of_inputs} s per tile.\")"
341351
]
342352
},

histomics_stream/configure.py

Lines changed: 152 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99
class FindResolutionForSlide:
1010
"""A class that computes read parameters for slides.
1111
12-
An instance of class FindResolutionForSlide is a callable that will add level,
13-
factor, number_pixel_rows_for_slide, and number_pixel_columns_for_slide fields to a
14-
slide dictionary.
12+
An instance of class FindResolutionForSlide is a callable that
13+
will add level, factor, number_pixel_rows_for_slide, and
14+
number_pixel_columns_for_slide fields to a slide dictionary.
1515
1616
Parameters for the constructor
1717
------------------------------
@@ -253,6 +253,10 @@ class TilesByGridAndMask:
253253
image ( in terms of its grid of tiles). A non-zero value in
254254
the mask indicates that the tile should be retained. The
255255
default is "", which means that there is no masking.
256+
mask_threshold : float
257+
A value in [0.0, 1.1]. A tile is retained if the fraction of
258+
the tile overlapping non-zero pixels in the mask is at least
259+
the mask_threshold.
256260
257261
"""
258262

@@ -263,6 +267,7 @@ def __init__(
263267
number_pixel_overlap_rows_for_tile=0, # Defaults to no overlap between adjacent tiles
264268
number_pixel_overlap_columns_for_tile=0,
265269
mask_filename="", # Defaults to no masking
270+
mask_threshold=0.0, # Defaults to any overlap with the mask
266271
):
267272
"""Sanity check the supplied parameters and store them for later use."""
268273
# Check values.
@@ -314,8 +319,16 @@ def __init__(
314319
mask_itk = itk.imread(mask_filename) # May throw exception
315320
if mask_itk.GetImageDimension() != 2:
316321
raise ValueError(
317-
f"The mask ({mask_filename})" " should be a 2-dimensional image."
322+
f"The mask ({mask_filename}) should be a 2-dimensional image."
318323
)
324+
if not (
325+
isinstance(mask_threshold, float)
326+
and mask_threshold >= 0.0
327+
and mask_threshold <= 1.0
328+
):
329+
raise ValueError(
330+
f"mask_threshold ({mask_threshold}) must be between 0 and 1 inclusive."
331+
)
319332

320333
# Save values. To keep garbage collection efficient don't save all of `study`.
321334
self.number_pixel_rows_for_tile = study["number_pixel_rows_for_tile"]
@@ -326,8 +339,9 @@ def __init__(
326339
number_pixel_overlap_columns_for_tile
327340
)
328341
self.mask_filename = mask_filename
329-
if mask_filename != "":
342+
if self.mask_filename != "":
330343
self.mask_itk = mask_itk
344+
self.mask_threshold = mask_threshold
331345

332346
def __call__(self, slide):
333347
"""Select tiles according to a regular grid. Optionally, restrict the list by a mask.
@@ -338,92 +352,109 @@ def __call__(self, slide):
338352
raise ValueError(
339353
'slide["number_pixel_rows_for_slide"] must be already set.'
340354
)
355+
self.number_pixel_rows_for_slide = slide["number_pixel_rows_for_slide"]
341356
if "number_pixel_columns_for_slide" not in slide:
342357
raise ValueError(
343358
'slide["number_pixel_columns_for_slide"] must be already set.'
344359
)
345360

361+
self.number_pixel_columns_for_slide = slide["number_pixel_columns_for_slide"]
362+
#
346363
# Do the work.
364+
#
347365
row_stride = (
348366
self.number_pixel_rows_for_tile - self.number_pixel_overlap_rows_for_tile
349367
)
350-
number_tile_rows_for_slide = slide["number_tile_rows_for_slide"] = math.floor(
351-
(
352-
slide["number_pixel_rows_for_slide"]
353-
- self.number_pixel_overlap_rows_for_tile
354-
)
355-
/ row_stride
356-
)
357368
column_stride = (
358369
self.number_pixel_columns_for_tile
359370
- self.number_pixel_overlap_columns_for_tile
360371
)
361-
number_tile_columns_for_slide = slide[
362-
"number_tile_columns_for_slide"
363-
] = math.floor(
372+
373+
# Return information to the user
374+
slide["number_tile_rows_for_slide"] = math.floor(
375+
(self.number_pixel_rows_for_slide - self.number_pixel_overlap_rows_for_tile)
376+
/ row_stride
377+
)
378+
slide["number_tile_columns_for_slide"] = math.floor(
364379
(
365-
slide["number_pixel_columns_for_slide"]
380+
self.number_pixel_columns_for_slide
366381
- self.number_pixel_overlap_columns_for_tile
367382
)
368383
/ column_stride
369384
)
385+
386+
# Pre-process the mask
370387
has_mask = hasattr(self, "mask_itk")
371388
if has_mask:
372-
# We will change the resolution of the mask (if necessary), which will
373-
# change the number of pixels, but will not change the overall physical size
374-
# represented by the image nor the position of the upper left corner of its
375-
# upper left pixel.
376-
input_size = itk.size(self.mask_itk)
377-
output_size = [number_tile_columns_for_slide, number_tile_rows_for_slide]
378-
if input_size != output_size:
379-
# print(f"Resampling from input_size = {input_size} to output_size = {output_size}")
380-
# Check that the input and output aspect ratios are pretty close
381-
if (
382-
abs(
383-
math.log(
384-
(output_size[0] / input_size[0])
385-
/ (output_size[1] / input_size[1])
389+
(
390+
self.number_pixel_rows_for_mask,
391+
self.number_pixel_columns_for_mask,
392+
) = self.mask_itk.shape
393+
slide["number_pixel_rows_for_mask"] = self.number_pixel_rows_for_mask
394+
slide["number_pixel_columns_for_mask"] = self.number_pixel_columns_for_mask
395+
396+
# Check that the input and output aspect ratios are pretty close
397+
if (
398+
abs(
399+
math.log(
400+
(
401+
self.number_pixel_columns_for_slide
402+
/ self.number_pixel_columns_for_mask
403+
)
404+
/ (
405+
self.number_pixel_rows_for_slide
406+
/ self.number_pixel_rows_for_mask
386407
)
387408
)
388-
> 0.20
389-
):
390-
raise ValueError(
391-
"The mask aspect ratio does not match that for the number of tiles."
392-
)
393-
input_spacing = itk.spacing(self.mask_itk)
394-
input_origin = itk.origin(self.mask_itk)
395-
image_dimension = self.mask_itk.GetImageDimension()
396-
output_spacing = [
397-
input_spacing[d] * input_size[d] / output_size[d]
398-
for d in range(image_dimension)
399-
]
400-
output_origin = [
401-
input_origin[d] + 0.5 * (output_spacing[d] - input_spacing[d])
402-
for d in range(image_dimension)
403-
]
404-
interpolator = itk.NearestNeighborInterpolateImageFunction.New(
405-
self.mask_itk
406409
)
407-
resampled_mask_itk = itk.resample_image_filter(
408-
self.mask_itk,
409-
interpolator=interpolator,
410-
size=output_size,
411-
output_spacing=output_spacing,
412-
output_origin=output_origin,
410+
> 0.20
411+
):
412+
raise ValueError(
413+
"The mask aspect ratio does not match that for the whole slide image."
413414
)
414-
else:
415-
resampled_mask_itk = self.mask_itk
416415

416+
# cumulative_mask[row, column] will be the number of mask_itk[r, c] (i.e.,
417+
# mask_itk.GetPixel((c,r))) values that are nonzero among all those with r <
418+
# row and c < column; note the strict inequalities. We have added a
419+
# boundary on all sides of this array -- zeros on the top and left, and a
420+
# duplicate row (column) on the bottom (right) -- so that we do not need to
421+
# do extra testing in our code at the borders. We use int64 in case there
422+
# are 2^31 (~2 billion = ~ 46k by 46k) or more non-zero pixel values in our
423+
# mask.
424+
self.cumulative_mask = np.zeros(
425+
(
426+
self.number_pixel_rows_for_mask + 2,
427+
self.number_pixel_columns_for_mask + 2,
428+
),
429+
dtype=np.int64,
430+
)
431+
nonzero = np.vectorize(lambda x: int(x != 0))
432+
self.cumulative_mask[
433+
1 : self.number_pixel_rows_for_mask + 1,
434+
1 : self.number_pixel_columns_for_mask + 1,
435+
] = nonzero(itk.GetArrayViewFromImage(self.mask_itk))
436+
self.cumulative_mask = np.cumsum(
437+
np.cumsum(self.cumulative_mask, axis=0), axis=1
438+
)
439+
440+
# Look at each tile in turn
417441
tiles = slide["tiles"] = {}
418442
number_of_tiles = 0
419-
for row in range(number_tile_rows_for_slide):
420-
for column in range(number_tile_columns_for_slide):
421-
if not (has_mask and resampled_mask_itk[row, column] == 0):
443+
top_too_high = (
444+
self.number_pixel_rows_for_slide - self.number_pixel_rows_for_tile + 1
445+
)
446+
left_too_high = (
447+
self.number_pixel_columns_for_slide - self.number_pixel_columns_for_tile + 1
448+
)
449+
for top in range(0, top_too_high, row_stride):
450+
for left in range(0, left_too_high, column_stride):
451+
if not (has_mask and self.mask_rejects(top, left)):
422452
tiles[f"tile_{number_of_tiles}"] = {
423-
"tile_top": row * row_stride,
424-
"tile_left": column * column_stride,
453+
"tile_top": top,
454+
"tile_left": left,
425455
}
426456
number_of_tiles += 1 # Increment even if tile is skipped.
457+
427458
# Choose a subset of the tiles randomly
428459
all_tile_names = tiles.keys()
429460
if 0 <= self.randomly_select < len(all_tile_names):
@@ -433,6 +464,66 @@ def __call__(self, slide):
433464
for key in keys_to_remove:
434465
del tiles[key]
435466

467+
def interpolate_cumulative(self, row, column):
468+
top = int(math.floor(row))
469+
left = int(math.floor(column))
470+
vertical_range = row - top
471+
horizontal_range = column - left
472+
response = (
473+
self.cumulative_mask[top, left]
474+
* (1.0 - vertical_range)
475+
* (1.0 - horizontal_range)
476+
+ self.cumulative_mask[top + 1, left]
477+
* vertical_range
478+
* (1.0 - horizontal_range)
479+
+ self.cumulative_mask[top, left + 1]
480+
* (1.0 - vertical_range)
481+
* horizontal_range
482+
+ self.cumulative_mask[top + 1, left + 1]
483+
* vertical_range
484+
* horizontal_range
485+
)
486+
return response
487+
488+
def mask_rejects(self, top, left):
489+
bottom = top + self.number_pixel_rows_for_tile
490+
right = left + self.number_pixel_columns_for_tile
491+
mask_top = (
492+
top / self.number_pixel_rows_for_slide * self.number_pixel_rows_for_mask
493+
)
494+
mask_bottom = (
495+
bottom / self.number_pixel_rows_for_slide * self.number_pixel_rows_for_mask
496+
)
497+
mask_left = (
498+
left
499+
/ self.number_pixel_columns_for_slide
500+
* self.number_pixel_columns_for_mask
501+
)
502+
mask_right = (
503+
right
504+
/ self.number_pixel_columns_for_slide
505+
* self.number_pixel_columns_for_mask
506+
)
507+
cumulative_top_left = self.interpolate_cumulative(mask_top, mask_left)
508+
cumulative_top_right = self.interpolate_cumulative(mask_top, mask_right)
509+
cumulative_bottom_left = self.interpolate_cumulative(mask_bottom, mask_left)
510+
cumulative_bottom_right = self.interpolate_cumulative(mask_bottom, mask_right)
511+
cumulative = (
512+
cumulative_bottom_right
513+
- cumulative_bottom_left
514+
- cumulative_top_right
515+
+ cumulative_top_left
516+
)
517+
if self.mask_threshold > 0:
518+
score = cumulative / (
519+
self.mask_threshold
520+
* (mask_bottom - mask_top)
521+
* (mask_right - mask_left)
522+
)
523+
return score < 0.999999
524+
else:
525+
return cumulative < 0.000001
526+
436527

437528
class TilesByList:
438529
"""Select the tiles supplied by the user. Optionally, select a

0 commit comments

Comments
 (0)