|
226 | 226 | "import torch\n", |
227 | 227 | "import torch.nn as nn\n", |
228 | 228 | "import torch.optim as optim\n", |
229 | | - "import torchmetrics\n", |
230 | 229 | "\n", |
231 | 230 | "\n", |
232 | 231 | "class SegmentationExperiment(pl.LightningModule):\n", |
|
236 | 235 | " multitask_loss: Dict[str, nn.Module],\n", |
237 | 236 | " optimizer: optim.Optimizer,\n", |
238 | 237 | " scheduler: optim.lr_scheduler._LRScheduler,\n", |
239 | | - " branch_metrics: Optional[Dict[str, List[torchmetrics.Metric]]] = None,\n", |
240 | 238 | " optimizer_kwargs: Optional[Dict[str, float]] = None,\n", |
241 | 239 | " scheduler_kwargs: Optional[Dict[str, float]] = None,\n", |
242 | 240 | " **kwargs,\n", |
|
253 | 251 | " self.optimizer_kwargs = optimizer_kwargs or {}\n", |
254 | 252 | " self.scheduler_kwargs = scheduler_kwargs or {}\n", |
255 | 253 | "\n", |
256 | | - " self.branch_metrics = branch_metrics\n", |
257 | 254 | " self.criterion = multitask_loss\n", |
258 | 255 | "\n", |
259 | 256 | " self._validate_branch_args()\n", |
|
0 commit comments