Conversation
This reverts commit efd9362.
voorhs
left a comment
There was a problem hiding this comment.
очень здорово что все работает, осталось только отшлифовать деталь с тем что не получилось сделать универсальный торч дампер
| with (path / Dumper.containers / "containers.json").open("w") as f: | ||
| json.dump(containers, f, ensure_ascii=False, indent=4) |
There was a problem hiding this comment.
тут могут быть несериализуемые типы, надо либо try except либо убедиться что все типы сериализуемы
| for model_dir in child.iterdir(): | ||
| with (model_dir / "class_info.json").open("r") as f: | ||
| class_info = json.load(f) | ||
| module = __import__(class_info["module"], fromlist=[class_info["name"]]) |
There was a problem hiding this comment.
лучше бы использовать importlib
There was a problem hiding this comment.
по поводу торч моделей в дампере: раз уж мы накладываем ограничение на объекты nn.Module что они должны иметь метод get_config то тогда надо в нашей библиотеке завести класс обертку для всех nn.Module и дать ему абстрактный метод get_config
а раз уж не получилось полноценно реализовать идею с универсальным дампером для торч моделей тогда лучше определять dump/load для каждой модели отдельно указав dump/load в этой абстрактной обертке для nn.Module
| def get_implicit_initialization_params(self) -> dict[str, Any]: | ||
| """Return default params used in initialization.""" | ||
| return { | ||
| "max_seq_length": self.max_seq_length, | ||
| "num_train_epochs": self.num_train_epochs, | ||
| "batch_size": self.batch_size, | ||
| "learning_rate": self.learning_rate, | ||
| "seed": self.seed, | ||
| "report_to": self.report_to, | ||
| "embed_dim": self.embed_dim, | ||
| "kernel_sizes": self.kernel_sizes, | ||
| "num_filters": self.num_filters, | ||
| "dropout": self.dropout | ||
| } |
There was a problem hiding this comment.
тут нужно пустой словарь вернуть, потому что смысл этого метода не в том чтобы вернуть все дефолтные параметры
|
|
||
| def fit(self, utterances: list[str], labels: ListOfLabels) -> None: | ||
| self._validate_task(labels) | ||
| self._multilabel = isinstance(labels[0], (list, np.ndarray)) # noqa: UP038 |
There was a problem hiding this comment.
этот атрибут не нужно устанавливать, он устанавливается в _validate_task
| probs = torch.softmax(outputs, dim=1).cpu().numpy() | ||
| all_probs.append(probs) | ||
|
|
||
| return np.concatenate(all_probs, axis=0) if all_probs else np.array([]) |
There was a problem hiding this comment.
почему тут может быть пустой выход? если вход изначально пустой?
No description provided.