diff --git a/Lecture_TUKR/tanaka/animal.py b/Lecture_TUKR/tanaka/animal.py new file mode 100644 index 00000000..e63ad26b --- /dev/null +++ b/Lecture_TUKR/tanaka/animal.py @@ -0,0 +1,28 @@ +import numpy as np +import os + + +def load_data(retlabel_animal=True, retlabel_feature=False): + datastore_name = 'datastore/animal' + file_name = 'features.txt' + + directory_path = os.path.join(os.path.dirname(__file__), datastore_name) + file_path = os.path.join(directory_path, file_name) + + x = np.loadtxt(file_path) + + return_objects = [x] + + if retlabel_animal: + label_name = 'labels_animal.txt' + label_path = os.path.join(directory_path, label_name) + label_animal = np.genfromtxt(label_path, dtype=str) + return_objects.append(label_animal) + + if retlabel_feature: + label_name = 'labels_feature.txt' + label_path = os.path.join(directory_path, label_name) + label_feature = np.genfromtxt(label_path, dtype=str) + return_objects.append(label_feature) + + return return_objects \ No newline at end of file diff --git a/Lecture_TUKR/tanaka/data_scratch_tanaka.py b/Lecture_TUKR/tanaka/data_scratch_tanaka.py new file mode 100644 index 00000000..8339fd72 --- /dev/null +++ b/Lecture_TUKR/tanaka/data_scratch_tanaka.py @@ -0,0 +1,42 @@ +import numpy as np +import matplotlib.pyplot as plt +from sklearn.preprocessing import MinMaxScaler + +def load_kura_tsom(xsamples, ysamples, missing_rate=None,retz=False): + # z1 = np.random.rand(xsamples)/2 + z1 = np.linspace(-1,1,xsamples) + # z2 = np.random.rand(ysamples)/2 + z2 = np.linspace(-1,1,ysamples) + + z1_repeated, z2_repeated = np.meshgrid(z1,z2) + x1 = z1_repeated + x2 = z2_repeated + x3 = (x1**2-x2**2) + #ノイズを加えたい時はここをいじる,locがガウス分布の平均、scaleが分散,size何個ノイズを作るか + #このノイズを加えることによって三次元空間のデータ点は上下に動く + + x = np.concatenate((x1[:, :, np.newaxis], x2[:, :, np.newaxis], x3[:, :, np.newaxis]), axis=2) + truez = np.concatenate((z1_repeated[:, :, np.newaxis], z2_repeated[:, :, np.newaxis]), axis=2) + # print(x.shape) + + if missing_rate == 0 or missing_rate == None: + if retz: + return x, truez, z1, z2 + else: + return x + +if __name__ == '__main__': + import matplotlib.pyplot as plt + from mpl_toolkits.mplot3d import Axes3D + + xsamples = 10 + ysamples = 10 + + x, truez = load_kura_tsom(xsamples,ysamples,retz=True) + + fig = plt.figure(figsize=[5, 5]) + ax_x = fig.add_subplot(projection='3d') + ax_x.scatter(x[:, :, 0].flatten(), x[:, :, 1].flatten(), x[:, :, 2].flatten(), c=x[:, :, 0].flatten()) + ax_x.set_title('Generated three-dimensional data') + plt.show() + diff --git a/Lecture_TUKR/tanaka/datastore/animal/features.txt b/Lecture_TUKR/tanaka/datastore/animal/features.txt new file mode 100644 index 00000000..6a1806a2 --- /dev/null +++ b/Lecture_TUKR/tanaka/datastore/animal/features.txt @@ -0,0 +1,17 @@ +1.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 1.0 0.0 0.0 0.5 0.5 0.0 0.0 0.0 +0.0 1.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 1.0 0.0 0.0 0.5 0.5 0.0 0.0 0.0 +0.0 1.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 1.0 0.5 0.5 0.0 0.0 0.0 +0.0 1.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 1.0 1.0 1.0 0.5 0.5 0.0 0.0 0.0 +0.0 1.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.3 0.0 0.0 1.0 1.0 0.0 0.5 0.5 0.0 0.0 0.0 +0.0 1.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 1.0 0.0 1.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 +0.0 0.0 1.0 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 1.0 0.0 1.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 +0.0 0.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 1.0 0.0 1.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 +0.0 1.0 0.0 0.5 0.0 1.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.3 0.7 1.0 0.0 0.0 +0.0 1.0 0.0 0.0 0.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 1.0 1.0 0.0 1.0 +0.0 1.0 0.0 1.0 0.0 1.0 1.0 0.0 0.0 0.0 0.0 1.0 1.0 0.0 0.0 0.0 0.0 1.0 1.0 0.0 0.0 +1.0 0.0 0.0 0.5 0.0 1.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 1.0 1.0 +0.0 0.0 1.0 0.5 0.0 1.0 1.0 0.0 0.0 0.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 1.0 0.0 +0.0 0.0 1.0 0.0 0.0 1.0 1.0 0.0 1.0 0.0 0.0 1.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 1.0 0.0 +0.0 0.0 1.0 0.0 0.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 1.0 0.0 0.0 1.0 1.0 0.0 0.0 0.0 0.0 +0.0 0.0 1.0 0.0 0.0 1.0 1.0 1.0 1.0 0.0 1.0 0.0 1.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 +0.0 0.0 1.0 0.0 0.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 1.0 0.0 0.0 0.0 0.0 \ No newline at end of file diff --git a/Lecture_TUKR/tanaka/datastore/animal/labels_animal.txt b/Lecture_TUKR/tanaka/datastore/animal/labels_animal.txt new file mode 100644 index 00000000..45934139 --- /dev/null +++ b/Lecture_TUKR/tanaka/datastore/animal/labels_animal.txt @@ -0,0 +1,17 @@ +dove +cock +duck +w_duck +owl +hawk +eagle +crow +fox +dog +wolf +cat +tiger +lion +horse +zebra +cattle \ No newline at end of file diff --git a/Lecture_TUKR/tanaka/datastore/animal/labels_feature.txt b/Lecture_TUKR/tanaka/datastore/animal/labels_feature.txt new file mode 100644 index 00000000..cb3c869f --- /dev/null +++ b/Lecture_TUKR/tanaka/datastore/animal/labels_feature.txt @@ -0,0 +1,21 @@ +small +medium +large +nocturnality +two_legs +four_legs +hair +hoof +mane +wing +stripe +hunt +run +fly +swim +domestic +herbivorous +carnivore +canidae +felidae +pet \ No newline at end of file diff --git a/Lecture_TUKR/tanaka/tmp.mp4 b/Lecture_TUKR/tanaka/tmp.mp4 new file mode 100644 index 00000000..ab11b3bf Binary files /dev/null and b/Lecture_TUKR/tanaka/tmp.mp4 differ diff --git a/Lecture_TUKR/tanaka/tukr.py b/Lecture_TUKR/tanaka/tukr.py new file mode 100644 index 00000000..44634d86 --- /dev/null +++ b/Lecture_TUKR/tanaka/tukr.py @@ -0,0 +1,175 @@ +import numpy as np +import jax,jaxlib +import jax.numpy as jnp +import tensorflow as tf +from tqdm import tqdm #プログレスバーを表示させてくれる +from sklearn.datasets import load_iris + + + +class TUKR: + def __init__(self, X, nb_samples1, nb_samples2, latent_dim1, latent_dim2, sigma1, sigma2, prior='random', Uinit=None, Vinit=None): + #--------初期値を設定する.--------- + self.X = X + #ここから下は書き換えてね + + if X.ndim == 3: + self.nb_samples1, self.nb_samples2, self.ob_dim = self.X.shape + else: + self.nb_samples1, self.nb_samples2 = self.X.shape + self.ob_dim = 1 + self.X = X[:,:,None] + + self.sigma1 = sigma1 + self.sigma2 = sigma2 + self.latent_dim1 = latent_dim1 + self.latent_dim2 = latent_dim2 + self.alpha = alpha + self.norm = norm + + if Uinit is None: + if prior == 'random': #一様事前分布のとき + self.U = np.random.uniform(-0.1, 0.1, size=(self.nb_samples1, self.latent_dim1)) + #(平均,標準偏差,配列のサイズ) + else: #ガウス事前分布のとき + self.U = np.random.uniform(0, 0.1 * self.sigma1, (self.nb_samples1, self.latent_dim1)) + else: #Zの初期値が与えられた時 + self.U = Uinit + + self.history = {} + + if Vinit is None: + if prior == 'random': # 一様事前分布のとき + self.V = np.random.normal(-0.1, 0.1 , size=(self.nb_samples2, self.latent_dim2)) + # (平均,標準偏差,配列のサイズ) + else: #ガウス事前分布のとき + self.V = np.random.normal(0, 0.1 * self.sigma2, (self.nb_samples2, self.latent_dim2)) + else: # Zの初期値が与えられた時 + self.V = Vinit + + self.history = {} + + def f(self, U, V): #写像の計算 + DistU = jnp.sum((U[:, None, :] - U[None, :, :]) ** 2, axis=2) + DistV = jnp.sum((V[:, None, :] - V[None, :, :]) ** 2, axis=2) + HU = jnp.exp((-1 * DistU) / (2 * (self.sigma1) ** 2)) + HV = jnp.exp((-1 * DistV) / (2 * (self.sigma2) ** 2)) + # GU = jnp.sum(HU, axis=1)[:, None] + # GV = jnp.sum(HV, axis=1)[:, None] + # RU = HU / GU + # RV = HV / GV + f = jnp.einsum('li,kj,ijd->lkd', HU, HV, self.X) + f1 = jnp.einsum('li,kj->lk', HU, HV) + f2 = f1[:, :, None] + return f / f2 + + def ff(self, U, V, epoch): #写像の計算 + DistU = jnp.sum((U[:, None, :] - self.history['u'][epoch][None, :, :]) ** 2, axis=2) + DistV = jnp.sum((V[:, None, :] - self.history['v'][epoch][None, :, :]) ** 2, axis=2) + HU = jnp.exp((-1 * DistU) / (2 * (self.sigma1) ** 2)) + HV = jnp.exp((-1 * DistV) / (2 * (self.sigma2) ** 2)) + # GU = jnp.sum(HU, axis=1)[:, None] + # GV = jnp.sum(HV, axis=1)[:, None] + # RU = HU / GU + # RV = HV / GV + f = jnp.einsum('li,kj,ijd->lkd', HU, HV, self.X) + f1 = jnp.einsum('li,kj->lk', HU, HV) + f2 = f1[:, :, None] + return f / f2 + + def E(self,U,V,X,alpha,norm):#目的関数の計算 + Y = self.f(U,V) + e = jnp.sum((X - Y) ** 2) + r = alpha*(jnp.sum(U**norm)+jnp.sum(V**norm)) + e = e/(self.nb_samples1*self.nb_samples2) + r = r/(self.nb_samples1*self.nb_samples2) + return e + r + + def fit(self, nb_epoch: int, eta: float,alpha,norm): + # 学習過程記録用 + self.history['u'] = np.zeros((nb_epoch, self.nb_samples1, self.latent_dim2)) + self.history['v'] = np.zeros((nb_epoch, self.nb_samples2, self.latent_dim1)) + self.history['f'] = np.zeros((nb_epoch, self.nb_samples1, self.nb_samples2, self.ob_dim)) + self.history['error'] = np.zeros(nb_epoch) + + for epoch in tqdm(range(nb_epoch)): + + # U,Vの更新 + dEdu = jax.grad(self.E, argnums=0)(self.U, self.V, self.X, alpha, norm) / self.nb_samples1 + self.U = self.U - eta * dEdu + + dEdv = jax.grad(self.E, argnums=1)(self.U, self.V, self.X, alpha, norm) / self.nb_samples2 + self.V = self.V - eta * dEdv + + # 学習過程記録用 + self.history['u'][epoch] = self.U + self.history['v'][epoch] = self.V + self.history['f'][epoch] = self.f(self.U, self.V) + self.history['error'][epoch] = self.E(self.U, self.V, self.X, alpha, norm) + + #--------------以下描画用(上の部分が実装できたら実装してね)--------------------- + def calc_approximate_f(self, resolution): #fのメッシュ描画用,resolution:一辺の代表点の数 + nb_epoch = self.history['u'].shape[0] + self.history['y'] = np.zeros((nb_epoch,self.nb_samples1,self.nb_samples2, self.ob_dim)) + for epoch in tqdm(range(nb_epoch)): + Uzeta = self.create_Uzeta(self.history['u'][epoch],resolution) + Vzeta = self.create_Vzeta(self.history['v'][epoch],resolution) + + y = self.ff(Uzeta,Vzeta,epoch) + self.history['y'][epoch] = y + + def create_Uzeta(self, U, resolution): #fのメッシュの描画用に潜在空間に代表点zetaを作る. + Uzeta = np.linspace(np.min(U), np.max(U),self.nb_samples1).reshape(-1,1) + + return Uzeta + + def create_Vzeta(self, V, resolution): # fのメッシュの描画用に潜在空間に代表点zetaを作る. + Vzeta = np.linspace(np.min(V), np.max(V),self.nb_samples2).reshape(-1,1) + + return Vzeta + + +if __name__ == '__main__': + # from Lecture_TUKR.tanaka.animal import load_data + from Lecture_TUKR.tanaka.data_scratch_tanaka import load_kura_tsom + # from Lecture_TUKR.tanaka.data_scratch_tanaka import create_rasen + # from Lecture_TUKR.tanaka.data_scratch_tanaka import create_2d_sin_curve + from visualizer import visualize_history + # from visualizer_animal import visualize_history + + #各種パラメータ変えて遊んでみてね. + epoch = 200 #学習回数 + sigma1 = 0.2 + sigma2 = 0.1 #カーネルの幅 + eta = 50 #学習率 + latent_dim1 = 2 #潜在空間の次元 + latent_dim2 = 2 #潜在空間の次元 + alpha = 0.1 + norm = 2 + seed = 4 + np.random.seed(seed) + + + + #入力データ(詳しくはdata.pyを除いてみると良い) + nb_samples1 = 10 #データ数 + nb_samples2 = 20 #データ数 + # data = load_data(retlabel_animal=True, retlabel_feature=True) + # X = load_iris() + X = load_kura_tsom(nb_samples1,nb_samples2) #鞍型データ ob_dim=3, 真のL=2 + # X = create_rasen(nb_samples) #らせん型データ ob_dim=3, 真のL=1 + # X = create_2d_sin_curve(nb_samples) #sin型データ ob_dim=2, 真のL=1 + + # X = data[0] + # animal_label = data[1] + # feature_label = data[2] + X, truez, z1, z2 = load_kura_tsom(nb_samples1, nb_samples2, retz=True) + allZ =[truez,z1,z2] + tukr = TUKR(X, nb_samples1, nb_samples2, latent_dim1, latent_dim2, sigma1, sigma2, prior='random') + tukr.fit(epoch, eta, alpha, norm) + # visualize_history(X, tukr.history['f'], tukr.history['u'],tukr.history['v'], tukr.history['error'], save_gif=False, filename="tmp") + + #----------描画部分が実装されたらコメントアウト外す---------- + tukr.calc_approximate_f(resolution=10) + # visualize_history(X, tukr.history['y'], tukr.history['u'],tukr.history['v'], tukr.history['error'],animal_label, feature_label, save_gif=False, filename="tmp") + visualize_history(X, tukr.history['y'], tukr.history['u'],tukr.history['v'], tukr.history['error'], save_gif=False, filename="tmp") \ No newline at end of file diff --git a/Lecture_TUKR/tanaka/visualizer.py b/Lecture_TUKR/tanaka/visualizer.py new file mode 100644 index 00000000..72651c28 --- /dev/null +++ b/Lecture_TUKR/tanaka/visualizer.py @@ -0,0 +1,109 @@ +import numpy as np +from matplotlib import pyplot as plt +from matplotlib.animation import FuncAnimation +from sklearn.preprocessing import MinMaxScaler + +STEP = 150 + +def visualize_history(X, Y_history, U_history, V_history, error_history, save_gif=False, filename="tmp", allZ=None): + input_dim, latent_dim1, latent_dim2 = X.shape[2], U_history[0].shape[1], V_history[0].shape[1] + input_projection_type = '3d' if input_dim > 2 else 'rectilinear' + + fig = plt.figure(figsize=(10, 8)) + gs = fig.add_gridspec(3, 3) + input_ax = fig.add_subplot(gs[0:2, 0], projection=input_projection_type) + latent1_ax = fig.add_subplot(gs[0:2, 1], aspect='equal') + latent1_ax.set_facecolor('w') + latent2_ax = fig.add_subplot(gs[0:2, 2], aspect='equal') + latent2_ax.set_facecolor('w') + error_ax = fig.add_subplot(gs[2, :]) + num_epoch = len(Y_history) + + if input_dim == 3 and latent_dim1 == 2 and latent_dim2 == 2: + resolution = int(np.sqrt(Y_history.shape[1])) + if Y_history.shape[1] == resolution ** 2: + Y_history = np.array(Y_history).reshape((num_epoch, resolution, resolution, input_dim)) + + observable_drawer = [None, None, draw_observable_2D, + draw_observable_3D][input_dim] + latent1_drawer = [None, draw_latent_1D, draw_latent_2D][latent_dim1] + latent2_drawer = [None, draw_latent_1D, draw_latent_2D][latent_dim2] + + ani = FuncAnimation( + fig, + update_graph, + frames=num_epoch, # // STEP, + repeat=True, + interval=10, + fargs=(observable_drawer, latent1_drawer, latent2_drawer, X, Y_history, U_history, V_history, error_history, fig, + input_ax, latent1_ax, latent2_ax, error_ax, num_epoch,allZ)) + plt.show() + if save_gif: + ani.save(f"{filename}.mp4", writer='ffmpeg') + + +def update_graph(epoch, observable_drawer, latent1_drawer, latent2_drawer, X, Y_history, + U_history, V_history, error_history, fig, input_ax, latent1_ax, latent2_ax, error_ax, num_epoch, allZ): + fig.suptitle(f"epoch: {epoch}") + input_ax.cla() + # input_ax.view_init(azim=(epoch * 400 / num_epoch), elev=30) + latent1_ax.cla() + latent2_ax.cla() + error_ax.cla() + + Y, U, V= Y_history[epoch], U_history[epoch], V_history[epoch] + (truez, z1, z2) = allZ + mmscaler = MinMaxScaler(feature_range=(0, 1), copy=True) + truez = mmscaler.fit_transform(truez.reshape(-1, 2)) + truez = truez.reshape(z1.shape[0], z2.shape[0], 2) + + z1 = mmscaler.fit_transform(z1[:, None]) + z2 = mmscaler.fit_transform(z2[:, None]) + colormap = np.ones((truez.shape[0], truez.shape[1], 3)) + colormap[:, :, 0] = truez[:, :, 0] + colormap[:, :, 2] = truez[:, :, 1] + colormap = colormap.reshape(-1, 3) + colormap1 = np.ones((z1.shape[0], 3)) + colormap1[:, 0] = z1[:, 0] + + colormap2 = np.ones((z2.shape[0], 3)) + colormap2[:, 2] = z2[:, 0] + + + observable_drawer(input_ax, X, Y, colormap) + latent1_drawer(latent1_ax, U, colormap1) + latent2_drawer(latent2_ax, V, colormap2) + draw_error(error_ax, error_history, epoch) + + +def draw_observable_3D(ax, X, Y, colormap): + ax.scatter(X[:, :, 0], X[:, :, 1], X[:, :, 2], c=colormap) + # ax.set_zlim(-1, 1) + if len(Y.shape) == 3: + ax.plot_wireframe(Y[:, :, 0], Y[:, :, 1], Y[:, :, 2], color='black') + # ax.scatter(Y[:, :, 0], Y[:, :, 1], Y[:, :, 2], color='black') + else: + ax.plot(Y[:,:,0], Y[:,:,1], Y[:,:,2], color='black') +# ax.plot(Y[:, 0], Y[:, 1], Y[:, 2], color='black') +# ax.plot_wireframe(Y[:, :, 0], Y[:, :, 1], Y[:, :, 2], color='black') + + +def draw_observable_2D(ax, X, Y, colormap, label): + ax.scatter(X[:, 0], X[:, 1], c=colormap) + ax.plot(Y[:, 0], Y[:, 1], c='black') + + +def draw_latent_2D(ax, Z, colormap, label): + ax.set_xlim(-1.1, 1.1) + ax.set_ylim(-1.1, 1.1) + ax.scatter(Z[:, 0], Z[:, 1], c=colormap) + + +def draw_latent_1D(ax, Z, colormap, label): + ax.scatter(Z, np.zeros(Z.shape), c=colormap,) + ax.set_ylim(-1, 1) + +def draw_error(ax, error_history, epoch): + ax.set_title("error_function", fontsize=8) + ax.plot(error_history, label='誤差関数') + ax.scatter(epoch, error_history[epoch], s=55, marker="*") diff --git a/Lecture_TUKR/tanaka/visualizer_animal.py b/Lecture_TUKR/tanaka/visualizer_animal.py new file mode 100644 index 00000000..f05261db --- /dev/null +++ b/Lecture_TUKR/tanaka/visualizer_animal.py @@ -0,0 +1,90 @@ +import numpy as np +from matplotlib import pyplot as plt +from matplotlib.animation import FuncAnimation + +STEP = 150 + + +def visualize_history(X, Y_history, U_history, V_history, error_history, animal_label, feature_label, save_gif=False, filename="tmp"): + input_dim, latent_dim1, latent_dim2 = 1, U_history[0].shape[1], V_history[0].shape[1] + input_projection_type = '3d' if input_dim > 2 else 'rectilinear' + + fig = plt.figure(figsize=(10, 8)) + gs = fig.add_gridspec(3, 2) + latent1_ax = fig.add_subplot(gs[0:2, 0], aspect='equal') + latent2_ax = fig.add_subplot(gs[0:2, 1], aspect='equal') #0:2図の大きさ + error_ax = fig.add_subplot(gs[2, :]) + num_epoch = len(Y_history) #マップの配置 + + if input_dim == 3 and latent_dim1 == 2 and latent_dim2 == 2: + resolution = int(np.sqrt(Y_history.shape[1])) + if Y_history.shape[1] == resolution ** 2: + Y_history = np.array(Y_history).reshape((num_epoch, resolution, resolution, input_dim)) + + observable_drawer = [None, None, draw_observable_2D, + draw_observable_3D][input_dim] + latent1_drawer = [None, draw_latent_1D, draw_latent_2D][latent_dim1] + latent2_drawer = [None, draw_latent_1D, draw_latent_2D][latent_dim2] + + ani = FuncAnimation( + fig, + update_graph, + frames=num_epoch, # // STEP, + repeat=True, + interval=10, + fargs=(latent1_drawer, latent2_drawer, X, Y_history, U_history, V_history, error_history, animal_label, feature_label, fig, latent1_ax, latent2_ax, error_ax, num_epoch)) + plt.show() + if save_gif: + ani.save(f"{filename}.mp4", writer='ffmpeg') + + +def update_graph(epoch, latent1_drawer, latent2_drawer, X, Y_history, + U_history, V_history, error_history, animal_label, feature_label,fig, latent1_ax, latent2_ax, error_ax, num_epoch): + fig.suptitle(f"epoch: {epoch}") + # input_ax.view_init(azim=(epoch * 400 / num_epoch), elev=30) + latent1_ax.cla() + latent2_ax.cla() + error_ax.cla() + + Y, U, V= Y_history[epoch], U_history[epoch], V_history[epoch] + colormap1 = X[:, 0] + colormap2 = X[0, :] + + latent1_drawer(latent1_ax,U,animal_label,colormap1) + latent2_drawer(latent2_ax,V,feature_label,colormap2) + draw_error(error_ax, error_history, epoch) + + +def draw_observable_3D(ax, X, Y, colormap): + ax.scatter(X[:, :, 0], X[:, :, 1], X[:, :, 2], c=colormap) + # ax.set_zlim(-1, 1) + if len(Y.shape) == 3: + ax.plot_wireframe(Y[:, :, 0], Y[:, :, 1], Y[:, :, 2], color='black') + # ax.scatter(Y[:, :, 0], Y[:, :, 1], Y[:, :, 2], color='black') + else: + ax.plot(Y[:,:,0], Y[:,:,1], Y[:,:,2], color='black') +# ax.plot(Y[:, 0], Y[:, 1], Y[:, 2], color='black') +# ax.plot_wireframe(Y[:, :, 0], Y[:, :, 1], Y[:, :, 2], color='black') + + +def draw_observable_2D(ax, X, Y, colormap): + ax.scatter(X[:, 0], X[:, 1], c=colormap) + ax.plot(Y[:, 0], Y[:, 1], c='black') + + +def draw_latent_2D(ax, Z, label,colormap): + ax.set_xlim(-1.1, 1.1) + ax.set_ylim(-1.1, 1.1) + for n in range (Z.shape[0]): + ax.text(Z[n, 0], Z[n, 1], label[n]) + ax.scatter(Z[:, 0], Z[:, 1], c=colormap) + + +def draw_latent_1D(ax, Z, colormap): + ax.scatter(Z, np.zeros(Z.shape), c=colormap) + ax.set_ylim(-1, 1) + +def draw_error(ax, error_history, epoch): + ax.set_title("error_function", fontsize=8) + ax.plot(error_history, label='誤差関数') + ax.scatter(epoch, error_history[epoch], s=55, marker="*") diff --git a/Lecture_UKR/tanaka/datasets/animal/features.txt b/Lecture_UKR/tanaka/datasets/animal/features.txt new file mode 100644 index 00000000..6a1806a2 --- /dev/null +++ b/Lecture_UKR/tanaka/datasets/animal/features.txt @@ -0,0 +1,17 @@ +1.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 1.0 0.0 0.0 0.5 0.5 0.0 0.0 0.0 +0.0 1.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 1.0 0.0 0.0 0.5 0.5 0.0 0.0 0.0 +0.0 1.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 1.0 0.5 0.5 0.0 0.0 0.0 +0.0 1.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 1.0 1.0 1.0 0.5 0.5 0.0 0.0 0.0 +0.0 1.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.3 0.0 0.0 1.0 1.0 0.0 0.5 0.5 0.0 0.0 0.0 +0.0 1.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 1.0 0.0 1.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 +0.0 0.0 1.0 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 1.0 0.0 1.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 +0.0 0.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 1.0 0.0 1.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 +0.0 1.0 0.0 0.5 0.0 1.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.3 0.7 1.0 0.0 0.0 +0.0 1.0 0.0 0.0 0.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 1.0 1.0 0.0 1.0 +0.0 1.0 0.0 1.0 0.0 1.0 1.0 0.0 0.0 0.0 0.0 1.0 1.0 0.0 0.0 0.0 0.0 1.0 1.0 0.0 0.0 +1.0 0.0 0.0 0.5 0.0 1.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 1.0 1.0 +0.0 0.0 1.0 0.5 0.0 1.0 1.0 0.0 0.0 0.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 1.0 0.0 +0.0 0.0 1.0 0.0 0.0 1.0 1.0 0.0 1.0 0.0 0.0 1.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 1.0 0.0 +0.0 0.0 1.0 0.0 0.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 1.0 0.0 0.0 1.0 1.0 0.0 0.0 0.0 0.0 +0.0 0.0 1.0 0.0 0.0 1.0 1.0 1.0 1.0 0.0 1.0 0.0 1.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 +0.0 0.0 1.0 0.0 0.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 1.0 0.0 0.0 0.0 0.0 \ No newline at end of file diff --git a/Lecture_UKR/tanaka/datasets/animal/labels_animal.txt b/Lecture_UKR/tanaka/datasets/animal/labels_animal.txt new file mode 100644 index 00000000..45934139 --- /dev/null +++ b/Lecture_UKR/tanaka/datasets/animal/labels_animal.txt @@ -0,0 +1,17 @@ +dove +cock +duck +w_duck +owl +hawk +eagle +crow +fox +dog +wolf +cat +tiger +lion +horse +zebra +cattle \ No newline at end of file diff --git a/Lecture_UKR/tanaka/datasets/animal/labels_feature.txt b/Lecture_UKR/tanaka/datasets/animal/labels_feature.txt new file mode 100644 index 00000000..cb3c869f --- /dev/null +++ b/Lecture_UKR/tanaka/datasets/animal/labels_feature.txt @@ -0,0 +1,21 @@ +small +medium +large +nocturnality +two_legs +four_legs +hair +hoof +mane +wing +stripe +hunt +run +fly +swim +domestic +herbivorous +carnivore +canidae +felidae +pet \ No newline at end of file diff --git a/Lecture_UKR/tanaka/shippai.gif b/Lecture_UKR/tanaka/shippai.gif deleted file mode 100644 index 7e9477a0..00000000 Binary files a/Lecture_UKR/tanaka/shippai.gif and /dev/null differ diff --git a/poetry.lock b/poetry.lock index 3377965e..73a504c8 100644 --- a/poetry.lock +++ b/poetry.lock @@ -927,14 +927,6 @@ numpy = ">=1.7" docs = ["sphinx (==1.2.3)", "sphinxcontrib-napoleon", "sphinx-rtd-theme", "numpydoc"] tests = ["pytest", "pytest-cov", "pytest-pep8"] -[[package]] -name = "package" -version = "0.1.1" -description = "package is a package to package your package" -category = "main" -optional = false -python-versions = "*" - [[package]] name = "packaging" version = "21.3" @@ -1195,23 +1187,23 @@ use_chardet_on_py3 = ["chardet (>=3.0.2,<5)"] [[package]] name = "scikit-learn" -version = "1.0.2" +version = "1.1.1" description = "A set of python modules for machine learning and data mining" category = "main" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" [package.dependencies] -joblib = ">=0.11" -numpy = ">=1.14.6" -scipy = ">=1.1.0" +joblib = ">=1.0.0" +numpy = ">=1.17.3" +scipy = ">=1.3.2" threadpoolctl = ">=2.0.0" [package.extras] -benchmark = ["matplotlib (>=2.2.3)", "pandas (>=0.25.0)", "memory-profiler (>=0.57.0)"] -docs = ["matplotlib (>=2.2.3)", "scikit-image (>=0.14.5)", "pandas (>=0.25.0)", "seaborn (>=0.9.0)", "memory-profiler (>=0.57.0)", "sphinx (>=4.0.1)", "sphinx-gallery (>=0.7.0)", "numpydoc (>=1.0.0)", "Pillow (>=7.1.2)", "sphinx-prompt (>=1.3.0)", "sphinxext-opengraph (>=0.4.2)"] -examples = ["matplotlib (>=2.2.3)", "scikit-image (>=0.14.5)", "pandas (>=0.25.0)", "seaborn (>=0.9.0)"] -tests = ["matplotlib (>=2.2.3)", "scikit-image (>=0.14.5)", "pandas (>=0.25.0)", "pytest (>=5.0.1)", "pytest-cov (>=2.9.0)", "flake8 (>=3.8.2)", "black (>=21.6b0)", "mypy (>=0.770)", "pyamg (>=4.0.0)"] +benchmark = ["matplotlib (>=3.1.2)", "pandas (>=1.0.5)", "memory-profiler (>=0.57.0)"] +docs = ["matplotlib (>=3.1.2)", "scikit-image (>=0.14.5)", "pandas (>=1.0.5)", "seaborn (>=0.9.0)", "memory-profiler (>=0.57.0)", "sphinx (>=4.0.1)", "sphinx-gallery (>=0.7.0)", "numpydoc (>=1.2.0)", "Pillow (>=7.1.2)", "sphinx-prompt (>=1.3.0)", "sphinxext-opengraph (>=0.4.2)"] +examples = ["matplotlib (>=3.1.2)", "scikit-image (>=0.14.5)", "pandas (>=1.0.5)", "seaborn (>=0.9.0)"] +tests = ["matplotlib (>=3.1.2)", "scikit-image (>=0.14.5)", "pandas (>=1.0.5)", "pytest (>=5.0.1)", "pytest-cov (>=2.9.0)", "flake8 (>=3.8.2)", "black (>=22.3.0)", "mypy (>=0.770)", "pyamg (>=4.0.0)", "numpydoc (>=1.2.0)"] [[package]] name = "scipy" @@ -1467,12 +1459,7 @@ testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest- [metadata] lock-version = "1.1" python-versions = "^3.8" -<<<<<<< HEAD -content-hash = "e5c013e46fd9644a82dbbde6592752cb8a7054e498e588547b2616ddc77e0486" -======= -content-hash = "257a017d2319bcb488debf9faed6d22168e21e3a75b59ee78ccb2a9fc7fc9ba6" ->>>>>>> main [metadata.files] absl-py = [ @@ -1957,9 +1944,6 @@ opt-einsum = [ {file = "opt_einsum-3.3.0-py3-none-any.whl", hash = "sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147"}, {file = "opt_einsum-3.3.0.tar.gz", hash = "sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549"}, ] -package = [ - {file = "package-0.1.1.tar.gz", hash = "sha256:01eee19a56a936bd63222f0a3c531fcdba37b5ad1bd833b960d62fb960b4955e"}, -] packaging = [ {file = "packaging-21.3-py3-none-any.whl", hash = "sha256:ef103e05f519cdc783ae24ea4e2e0f508a9c99b2d4969652eed6a2e1ea5bd522"}, {file = "packaging-21.3.tar.gz", hash = "sha256:dd47c42927d89ab911e606518907cc2d3a1f38bbd026385970643f9c5b8ecfeb"}, @@ -2200,38 +2184,24 @@ requests = [ {file = "requests-2.27.1.tar.gz", hash = "sha256:68d7c56fd5a8999887728ef304a6d12edc7be74f1cfa47714fc8b414525c9a61"}, ] scikit-learn = [ - {file = "scikit-learn-1.0.2.tar.gz", hash = "sha256:b5870959a5484b614f26d31ca4c17524b1b0317522199dc985c3b4256e030767"}, - {file = "scikit_learn-1.0.2-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:da3c84694ff693b5b3194d8752ccf935a665b8b5edc33a283122f4273ca3e687"}, - {file = "scikit_learn-1.0.2-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:75307d9ea39236cad7eea87143155eea24d48f93f3a2f9389c817f7019f00705"}, - {file = "scikit_learn-1.0.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f14517e174bd7332f1cca2c959e704696a5e0ba246eb8763e6c24876d8710049"}, - {file = "scikit_learn-1.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d9aac97e57c196206179f674f09bc6bffcd0284e2ba95b7fe0b402ac3f986023"}, - {file = "scikit_learn-1.0.2-cp310-cp310-win_amd64.whl", hash = "sha256:d93d4c28370aea8a7cbf6015e8a669cd5d69f856cc2aa44e7a590fb805bb5583"}, - {file = "scikit_learn-1.0.2-cp37-cp37m-macosx_10_13_x86_64.whl", hash = "sha256:85260fb430b795d806251dd3bb05e6f48cdc777ac31f2bcf2bc8bbed3270a8f5"}, - {file = "scikit_learn-1.0.2-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:a053a6a527c87c5c4fa7bf1ab2556fa16d8345cf99b6c5a19030a4a7cd8fd2c0"}, - {file = "scikit_learn-1.0.2-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:245c9b5a67445f6f044411e16a93a554edc1efdcce94d3fc0bc6a4b9ac30b752"}, - {file = "scikit_learn-1.0.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:158faf30684c92a78e12da19c73feff9641a928a8024b4fa5ec11d583f3d8a87"}, - {file = "scikit_learn-1.0.2-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:08ef968f6b72033c16c479c966bf37ccd49b06ea91b765e1cc27afefe723920b"}, - {file = "scikit_learn-1.0.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:16455ace947d8d9e5391435c2977178d0ff03a261571e67f627c8fee0f9d431a"}, - {file = "scikit_learn-1.0.2-cp37-cp37m-win32.whl", hash = "sha256:2f3b453e0b149898577e301d27e098dfe1a36943f7bb0ad704d1e548efc3b448"}, - {file = "scikit_learn-1.0.2-cp37-cp37m-win_amd64.whl", hash = "sha256:46f431ec59dead665e1370314dbebc99ead05e1c0a9df42f22d6a0e00044820f"}, - {file = "scikit_learn-1.0.2-cp38-cp38-macosx_10_13_x86_64.whl", hash = "sha256:ff3fa8ea0e09e38677762afc6e14cad77b5e125b0ea70c9bba1992f02c93b028"}, - {file = "scikit_learn-1.0.2-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:9369b030e155f8188743eb4893ac17a27f81d28a884af460870c7c072f114243"}, - {file = "scikit_learn-1.0.2-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:7d6b2475f1c23a698b48515217eb26b45a6598c7b1840ba23b3c5acece658dbb"}, - {file = "scikit_learn-1.0.2-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:285db0352e635b9e3392b0b426bc48c3b485512d3b4ac3c7a44ec2a2ba061e66"}, - {file = "scikit_learn-1.0.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5cb33fe1dc6f73dc19e67b264dbb5dde2a0539b986435fdd78ed978c14654830"}, - {file = "scikit_learn-1.0.2-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b1391d1a6e2268485a63c3073111fe3ba6ec5145fc957481cfd0652be571226d"}, - {file = "scikit_learn-1.0.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc3744dabc56b50bec73624aeca02e0def06b03cb287de26836e730659c5d29c"}, - {file = "scikit_learn-1.0.2-cp38-cp38-win32.whl", hash = "sha256:a999c9f02ff9570c783069f1074f06fe7386ec65b84c983db5aeb8144356a355"}, - {file = "scikit_learn-1.0.2-cp38-cp38-win_amd64.whl", hash = "sha256:7626a34eabbf370a638f32d1a3ad50526844ba58d63e3ab81ba91e2a7c6d037e"}, - {file = "scikit_learn-1.0.2-cp39-cp39-macosx_10_13_x86_64.whl", hash = "sha256:a90b60048f9ffdd962d2ad2fb16367a87ac34d76e02550968719eb7b5716fd10"}, - {file = "scikit_learn-1.0.2-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:7a93c1292799620df90348800d5ac06f3794c1316ca247525fa31169f6d25855"}, - {file = "scikit_learn-1.0.2-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:eabceab574f471de0b0eb3f2ecf2eee9f10b3106570481d007ed1c84ebf6d6a1"}, - {file = "scikit_learn-1.0.2-cp39-cp39-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:55f2f3a8414e14fbee03782f9fe16cca0f141d639d2b1c1a36779fa069e1db57"}, - {file = "scikit_learn-1.0.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:80095a1e4b93bd33261ef03b9bc86d6db649f988ea4dbcf7110d0cded8d7213d"}, - {file = "scikit_learn-1.0.2-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fa38a1b9b38ae1fad2863eff5e0d69608567453fdfc850c992e6e47eb764e846"}, - {file = "scikit_learn-1.0.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ff746a69ff2ef25f62b36338c615dd15954ddc3ab8e73530237dd73235e76d62"}, - {file = "scikit_learn-1.0.2-cp39-cp39-win32.whl", hash = "sha256:e174242caecb11e4abf169342641778f68e1bfaba80cd18acd6bc84286b9a534"}, - {file = "scikit_learn-1.0.2-cp39-cp39-win_amd64.whl", hash = "sha256:b54a62c6e318ddbfa7d22c383466d38d2ee770ebdb5ddb668d56a099f6eaf75f"}, + {file = "scikit-learn-1.1.1.tar.gz", hash = "sha256:3e77b71e8e644f86c8b5be7f1c285ef597de4c384961389ee3e9ca36c445b256"}, + {file = "scikit_learn-1.1.1-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:102f51797cd8944bf44a038d106848ddf2804f2c1edf7aea45fba81a4fdc4d80"}, + {file = "scikit_learn-1.1.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:723cdb278b1fa57a55f68945bc4e501a2f12abe82f76e8d21e1806cbdbef6fc5"}, + {file = "scikit_learn-1.1.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:33cf061ed0b79d647a3e4c3f6c52c412172836718a7cd4d11c1318d083300133"}, + {file = "scikit_learn-1.1.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47464c110eaa9ed9d1fe108cb403510878c3d3a40f110618d2a19b2190a3e35c"}, + {file = "scikit_learn-1.1.1-cp310-cp310-win_amd64.whl", hash = "sha256:542ccd2592fe7ad31f5c85fed3a3deb3e252383960a85e4b49a629353fffaba4"}, + {file = "scikit_learn-1.1.1-cp38-cp38-macosx_10_13_x86_64.whl", hash = "sha256:3be10d8d325821ca366d4fe7083d87c40768f842f54371a9c908d97c45da16fc"}, + {file = "scikit_learn-1.1.1-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:b2db720e13e697d912a87c1a51194e6fb085dc6d8323caa5ca51369ca6948f78"}, + {file = "scikit_learn-1.1.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e851f8874398dcd50d1e174e810e9331563d189356e945b3271c0e19ee6f4d6f"}, + {file = "scikit_learn-1.1.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b928869072366dc138762fe0929e7dc88413f8a469aebc6a64adc10a9226180c"}, + {file = "scikit_learn-1.1.1-cp38-cp38-win32.whl", hash = "sha256:e9d228ced1214d67904f26fb820c8abbea12b2889cd4aa8cda20a4ca0ed781c1"}, + {file = "scikit_learn-1.1.1-cp38-cp38-win_amd64.whl", hash = "sha256:f2d5b5d6e87d482e17696a7bfa03fe9515fdfe27e462a4ad37f3d7774a5e2fd6"}, + {file = "scikit_learn-1.1.1-cp39-cp39-macosx_10_13_x86_64.whl", hash = "sha256:0403ad13f283e27d43b0ad875f187ec7f5d964903d92d1ed06c51439560ecea0"}, + {file = "scikit_learn-1.1.1-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:8fe80df08f5b9cee5dd008eccc672e543976198d790c07e5337f7dfb67eaac05"}, + {file = "scikit_learn-1.1.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8ff56d07b9507fbe07ca0f4e5c8f3e171f74a429f998da03e308166251316b34"}, + {file = "scikit_learn-1.1.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c2dad2bfc502344b869d4a3f4aa7271b2a5f4fe41f7328f404844c51612e2c58"}, + {file = "scikit_learn-1.1.1-cp39-cp39-win32.whl", hash = "sha256:22145b60fef02e597a8e7f061ebc7c51739215f11ce7fcd2ca9af22c31aa9f86"}, + {file = "scikit_learn-1.1.1-cp39-cp39-win_amd64.whl", hash = "sha256:45c0f6ae523353f1d99b85469d746f9c497410adff5ba8b24423705b6956a86e"}, ] scipy = [ {file = "scipy-1.6.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:a15a1f3fc0abff33e792d6049161b7795909b40b97c6cc2934ed54384017ab76"}, diff --git a/pyproject.toml b/pyproject.toml index 4b5a7fe4..5a9d63b7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,14 +13,6 @@ jupyter = "^1.0.0" jupyterlab = "^3.3.2" jupyterlab-git = "^0.36.0" matplotlib = "^3.5.1" -<<<<<<< HEAD -scikit-learn = "^1.0.2" -tqdm = "^4.64.0" -======= -tqdm = "^4.64.0" - -torch = "^1.11.0" ->>>>>>> main [tool.poetry.dev-dependencies]