diff --git a/sf3d/models/camera.py b/sf3d/models/camera.py index 7fed714..9917d20 100644 --- a/sf3d/models/camera.py +++ b/sf3d/models/camera.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import List +from typing import List, Dict, Any import torch import torch.nn as nn @@ -19,7 +19,7 @@ class Config(BaseModule.Config): def configure(self) -> None: self.linear = nn.Linear(self.cfg.in_channels, self.cfg.out_channels) - def forward(self, **kwargs): + def forward(self, **kwargs: Dict[str, torch.Tensor]) -> torch.Tensor: cond_tensors = [] for cond_name in self.cfg.conditions: assert cond_name in kwargs @@ -30,3 +30,4 @@ def forward(self, **kwargs): assert cond_tensor.shape[-1] == self.cfg.in_channels embedding = self.linear(cond_tensor) return embedding +