File tree Expand file tree Collapse file tree 2 files changed +12
-5
lines changed
Expand file tree Collapse file tree 2 files changed +12
-5
lines changed Original file line number Diff line number Diff line change 66from .logger import info
77
88
9- def get_embeddings (use_gpu = False , images = None ):
9+ def get_embeddings (use_gpu = False , images = None , model = 'resnet-18' ):
1010 """
1111 This Python function initializes an Img2Vec object, runs it on either GPU or CPU, and retrieves
1212 image embeddings.
13+ :param use_gpu: The `use_gpu` parameter is a boolean that specifies whether to use GPU or CPU.
14+ :param images: The `images` parameter is a list of image paths to be used for generating embeddings.
15+ :param model: The `model` parameter is a string that specifies the model to use for generating.
16+ For available models, see https://github.com/christiansafka/img2vec
17+ :return: The function `get_embeddings` returns the embeddings of the images as np.ndarray.
1318 """
1419
1520 info (f"Img2Vec is running on { 'GPU' if use_gpu else 'CPU' } ..." )
16- img2vec = Img2Vec (cuda = use_gpu )
17-
21+ img2vec = Img2Vec (cuda = use_gpu , model = model )
22+ print ( f"Using model: { model } " )
1823 embeddings = img2vec .get_vec (images , tensor = False )
1924 return embeddings
2025
Original file line number Diff line number Diff line change @@ -44,21 +44,23 @@ def read(self, folder_path):
4444 self .image_paths = read_images_from_directory (folder_path )
4545 self .images = read_with_pil (self .image_paths )
4646
47- def calculate (self , pca = True , iter = 10 ):
47+ def calculate (self , pca = True , iter = 10 , model = "resnet-18" ):
4848 """
4949 The function calculates embeddings, performs PCA, and applies K-means clustering to the
5050 embeddings. It will not perform these operations if no images have been read.
5151
5252 :param pca: The `pca` parameter is a boolean that specifies whether to perform PCA or not. Default is True
5353 :param iter: The `iter` parameter is an integer that specifies the number of iterations for the KMeans algorithm. Default is 10.
54+ :param model: The `model` parameter is a string that specifies the model to use for generating embeddings. Default is 'resnet-18'.
55+ For available models, see https://github.com/christiansafka/img2vec
5456 """
5557
5658 if not self .images :
5759 raise ValueError (
5860 "The images list can not be empty. Please call the read method before calculating."
5961 )
6062
61- self .embeddings = get_embeddings (use_gpu = self .use_gpu , images = self .images )
63+ self .embeddings = get_embeddings (use_gpu = self .use_gpu , images = self .images , model = model )
6264 if pca :
6365 self .pca_embeddings = calculate_pca (self .embeddings , self .pca_dim )
6466 self .centroid , self .labels , self .counts = calculate_kmeans (
You can’t perform that action at this time.
0 commit comments