Skip to content

Commit 1110379

Browse files
Created using Colab
1 parent ef026aa commit 1110379

File tree

1 file changed

+55
-32
lines changed

1 file changed

+55
-32
lines changed

notebooks/train_Cellpose-SAM.ipynb

Lines changed: 55 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
{
44
"cell_type": "markdown",
55
"metadata": {
6-
"colab_type": "text",
7-
"id": "view-in-github"
6+
"id": "view-in-github",
7+
"colab_type": "text"
88
},
99
"source": [
1010
"<a href=\"https://colab.research.google.com/github/MouseLand/cellpose/blob/main/notebooks/train_Cellpose-SAM.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
@@ -40,7 +40,9 @@
4040
},
4141
{
4242
"cell_type": "markdown",
43-
"metadata": {},
43+
"metadata": {
44+
"id": "5-Pn33KNqmE6"
45+
},
4446
"source": [
4547
"### Mount your google drive\n",
4648
"\n",
@@ -69,7 +71,9 @@
6971
},
7072
{
7173
"cell_type": "markdown",
72-
"metadata": {},
74+
"metadata": {
75+
"id": "oK-WIju4qmE7"
76+
},
7377
"source": [
7478
"\n",
7579
"Then click on \"Folder\" icon on the Left, press the refresh button. Your Google Drive folder should now be available here as \"gdrive\".\n",
@@ -114,7 +118,9 @@
114118
},
115119
{
116120
"cell_type": "markdown",
117-
"metadata": {},
121+
"metadata": {
122+
"id": "CnKdFgZTqmE9"
123+
},
118124
"source": [
119125
"Check GPU and instantiate model - will download weights."
120126
]
@@ -137,17 +143,19 @@
137143
"io.logger_setup() # run this to get printing of progress\n",
138144
"\n",
139145
"#Check if colab notebook instance has GPU access\n",
140-
"if core.use_gpu()==False: \n",
146+
"if core.use_gpu()==False:\n",
141147
" raise ImportError(\"No GPU access, change your runtime\")\n",
142-
" \n",
148+
"\n",
143149
"model = models.CellposeModel(gpu=True)"
144150
]
145151
},
146152
{
147153
"cell_type": "markdown",
148-
"metadata": {},
154+
"metadata": {
155+
"id": "plEha5EaqmE9"
156+
},
149157
"source": [
150-
"Input directory with your images:"
158+
"Input directory with your images (if you have them, otherwise use sample images):"
151159
]
152160
},
153161
{
@@ -184,17 +192,21 @@
184192
},
185193
{
186194
"cell_type": "markdown",
187-
"metadata": {},
195+
"metadata": {
196+
"id": "0JnV_E_OqmE9"
197+
},
188198
"source": [
189199
"### Sample images (optional)\n",
190200
"\n",
191-
"You can use our sample images instead of mounting your google drive "
201+
"You can use our sample images instead of mounting your google drive"
192202
]
193203
},
194204
{
195205
"cell_type": "code",
196206
"execution_count": null,
197-
"metadata": {},
207+
"metadata": {
208+
"id": "sG96J_V8qmE-"
209+
},
198210
"outputs": [],
199211
"source": [
200212
"from natsort import natsorted\n",
@@ -214,18 +226,22 @@
214226
},
215227
{
216228
"cell_type": "markdown",
217-
"metadata": {},
229+
"metadata": {
230+
"id": "dJFJG-mkqmE-"
231+
},
218232
"source": [
219233
"## Train new model"
220234
]
221235
},
222236
{
223237
"cell_type": "code",
224238
"execution_count": null,
225-
"metadata": {},
239+
"metadata": {
240+
"id": "r0umDFliqmE-"
241+
},
226242
"outputs": [],
227243
"source": [
228-
"from cellpose import train \n",
244+
"from cellpose import train\n",
229245
"\n",
230246
"model_name = \"new_model\"\n",
231247
"\n",
@@ -238,23 +254,24 @@
238254
"# get files\n",
239255
"output = io.load_train_test_data(train_dir, test_dir, mask_filter=masks_ext)\n",
240256
"train_data, train_labels, _, test_data, test_labels, _ = output\n",
257+
"# (not passing test data into function to speed up training)\n",
241258
"\n",
242-
"new_model_path, train_losses, test_losses = train.train_seg(model.net, \n",
243-
" train_data=train_data, \n",
244-
" train_labels=train_labels, \n",
245-
" test_data=test_data,\n",
246-
" test_labels=test_labels,\n",
259+
"new_model_path, train_losses, test_losses = train.train_seg(model.net,\n",
260+
" train_data=train_data,\n",
261+
" train_labels=train_labels,\n",
247262
" batch_size=batch_size,\n",
248263
" n_epochs=n_epochs,\n",
249-
" learning_rate=learning_rate, \n",
250-
" weight_decay=weight_decay, \n",
251-
" nimg_per_epoch=min(8, len(files)),\n",
264+
" learning_rate=learning_rate,\n",
265+
" weight_decay=weight_decay,\n",
266+
" nimg_per_epoch=max(2, len(train_data)), # can change this\n",
252267
" model_name=model_name)\n"
253268
]
254269
},
255270
{
256271
"cell_type": "markdown",
257-
"metadata": {},
272+
"metadata": {
273+
"id": "gj0EdXtcqmE-"
274+
},
258275
"source": [
259276
"## Evaluate on test data (optional)\n",
260277
"\n",
@@ -264,12 +281,14 @@
264281
{
265282
"cell_type": "code",
266283
"execution_count": null,
267-
"metadata": {},
284+
"metadata": {
285+
"id": "Y2Gv4KnSqmE-"
286+
},
268287
"outputs": [],
269288
"source": [
270-
"from cellpose import metrics \n",
289+
"from cellpose import metrics\n",
271290
"\n",
272-
"model = models.CellposeModel(gpu=True, \n",
291+
"model = models.CellposeModel(gpu=True,\n",
273292
" pretrained_model=new_model_path)\n",
274293
"\n",
275294
"# run model on test images\n",
@@ -283,15 +302,19 @@
283302
},
284303
{
285304
"cell_type": "markdown",
286-
"metadata": {},
305+
"metadata": {
306+
"id": "OddRFdtEqmE-"
307+
},
287308
"source": [
288309
"plot masks"
289310
]
290311
},
291312
{
292313
"cell_type": "code",
293314
"execution_count": null,
294-
"metadata": {},
315+
"metadata": {
316+
"id": "9MUrvy5JqmE-"
317+
},
295318
"outputs": [],
296319
"source": [
297320
"plt.figure(figsize=(12,8), dpi=150)\n",
@@ -323,8 +346,8 @@
323346
"metadata": {
324347
"accelerator": "GPU",
325348
"colab": {
326-
"include_colab_link": true,
327-
"provenance": []
349+
"provenance": [],
350+
"include_colab_link": true
328351
},
329352
"kernelspec": {
330353
"display_name": "cellpose",
@@ -346,4 +369,4 @@
346369
},
347370
"nbformat": 4,
348371
"nbformat_minor": 0
349-
}
372+
}

0 commit comments

Comments
 (0)