Skip to content
This repository was archived by the owner on Feb 27, 2026. It is now read-only.

Commit 9b34912

Browse files
committed
Update missing dependency
1 parent 32ee3b3 commit 9b34912

File tree

2 files changed

+22
-3
lines changed

2 files changed

+22
-3
lines changed

pycave/bayes/gmm.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@
33
import torch
44
import torch.nn as nn
55
import torch.distributions as dist
6+
from sklearn.cluster import KMeans
67
import pyblaze.nn as xnn
78
from pyblaze.utils.stdio import ProgressBar
8-
from pyblaze.utils.torch import to_one_hot
9-
from sklearn.cluster import KMeans
10-
from .utils import log_normal, log_responsibilities, max_likeli_means, max_likeli_covars
9+
from .utils import log_normal, log_responsibilities, max_likeli_means, max_likeli_covars, \
10+
to_one_hot
1111

1212
class GMMConfig(xnn.Config):
1313
"""

pycave/bayes/utils.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,3 +193,22 @@ def power_iteration(A, eps=1e-7, max_iterations=100):
193193
break
194194

195195
return v
196+
197+
198+
def to_one_hot(X, n):
199+
"""
200+
Creates a one-hot matrix from a set of indices.
201+
202+
Parameters:
203+
-----------
204+
- X: torch.Tensor [N, D]
205+
The indices to convert into one-hot vectors.
206+
- n: int
207+
The number of entries in the one-hot vectors.
208+
209+
Returns:
210+
--------
211+
- torch.Tensor [N, D, n]
212+
The one-hot matrix.
213+
"""
214+
return torch.eye(n, device=X.device)[X]

0 commit comments

Comments
 (0)