|
12 | 12 | from skimage.filters import gaussian |
13 | 13 | from tifffile import imread as tfl_imread |
14 | 14 | from tqdm import tqdm |
| 15 | +import importlib.util |
15 | 16 |
|
16 | 17 | """ |
17 | 18 | utils.py |
18 | 19 | ==================================== |
19 | 20 | Definitions of utility functions and variables |
20 | 21 | """ |
21 | 22 |
|
22 | | -################## |
23 | | -################## |
24 | | -# dev util |
25 | | -def ENABLE_TEST_MODE(): |
26 | | - path = Path(os.path.expanduser("~")) |
27 | | - # print(path) |
28 | | - print("TEST MODE ENABLED, DEV ONLY") |
29 | | - if path == Path("C:/Users/Cyril"): |
30 | | - return True |
31 | | - return False |
32 | | - |
33 | | - |
34 | | -################## |
35 | | -################## |
36 | | - |
37 | | - |
38 | 23 | def normalize_x(image): |
39 | 24 | """Normalizes the values of an image array to be between [-1;1] rather than [0;255] |
40 | 25 |
|
@@ -996,3 +981,64 @@ def merge_imgs(imgs, original_image_shape): |
996 | 981 |
|
997 | 982 | print(merged_imgs.shape) |
998 | 983 | return merged_imgs |
| 984 | + |
| 985 | + |
| 986 | +def read_plainconfig(configname): |
| 987 | + """ |
| 988 | + This code is adapted from DeepLabCut with permission from MWMathis |
| 989 | + """ |
| 990 | + if not os.path.exists(configname): |
| 991 | + raise FileNotFoundError( |
| 992 | + f"Config {configname} is not found. Please make sure that the file exists." |
| 993 | + ) |
| 994 | + with open(configname) as file: |
| 995 | + return YAML().load(file) |
| 996 | + |
| 997 | +def DownloadModel(modelname, target_dir): |
| 998 | + """ |
| 999 | + Downloads a specific pretained model. |
| 1000 | + This code is adapted from DeepLabCut with permission from MWMathis |
| 1001 | + """ |
| 1002 | + import urllib.request |
| 1003 | + import tarfile |
| 1004 | + from tqdm import tqdm |
| 1005 | + |
| 1006 | + def show_progress(count, block_size, total_size): |
| 1007 | + pbar.update(block_size) |
| 1008 | + |
| 1009 | + def tarfilenamecutting(tarf): |
| 1010 | + """' auxfun to extract folder path |
| 1011 | + ie. /xyz-trainsetxyshufflez/ |
| 1012 | + """ |
| 1013 | + for memberid, member in enumerate(tarf.getmembers()): |
| 1014 | + if memberid == 0: |
| 1015 | + parent = str(member.path) |
| 1016 | + l = len(parent) + 1 |
| 1017 | + if member.path.startswith(parent): |
| 1018 | + member.path = member.path[l:] |
| 1019 | + yield member |
| 1020 | + #TODO: fix error in line 1021; |
| 1021 | + cellseg3d_path = os.path.split(importlib.util.find_spec("napari-cellseg3d").origin)[0] |
| 1022 | + neturls = read_plainconfig(os.path.join(cellseg3d_path,"models","pretrained","pretrained_model_urls.yaml",)) |
| 1023 | + |
| 1024 | + if modelname in neturls.keys(): |
| 1025 | + url = neturls[modelname] |
| 1026 | + response = urllib.request.urlopen(url) |
| 1027 | + print( |
| 1028 | + "Downloading the model from the M.W. Mathis Lab server {}....".format( |
| 1029 | + url |
| 1030 | + ) |
| 1031 | + ) |
| 1032 | + total_size = int(response.getheader("Content-Length")) |
| 1033 | + pbar = tqdm(unit="B", total=total_size, position=0) |
| 1034 | + filename, _ = urllib.request.urlretrieve(url, reporthook=show_progress) |
| 1035 | + with tarfile.open(filename, mode="r:gz") as tar: |
| 1036 | + tar.extractall(target_dir, members=tarfilenamecutting(tar)) |
| 1037 | + else: |
| 1038 | + models = [ |
| 1039 | + fn |
| 1040 | + for fn in neturls.keys() |
| 1041 | + if "VNet_" not in fn and "SegResNet" not in fn and "TRAILMAP_" not in fn |
| 1042 | + ] |
| 1043 | + print("Model does not exist: ", modelname) |
| 1044 | + #print("Pick one of the following: ", models) |
0 commit comments