Skip to content

Commit e15251b

Browse files
Merge pull request #15 from computational-cell-analytics/dtype-support
Convert inputs to uint8
2 parents ea84438 + 6c4c210 commit e15251b

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

micro_sam/util.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,13 @@ def get_sam_model(device=None, model_type="vit_h", checkpoint_path=None, return_
105105

106106

107107
def _to_image(input_):
108+
# we require the input to be uint8
109+
if input_.dtype != np.dtype("uint8"):
110+
# first normalize the input to [0, 1]
111+
input_ = input_.astype("float32") - input_.min()
112+
input_ = input_ / input_.max()
113+
# then bring to [0, 255] and cast to uint8
114+
input_ = (input_ * 255).astype("uint8")
108115
if input_.ndim == 2:
109116
image = np.concatenate([input_[..., None]] * 3, axis=-1)
110117
elif input_.ndim == 3 and input_.shape[-1] == 3:

0 commit comments

Comments
 (0)