|
63 | 63 | "cosine": 2, |
64 | 64 | "hellinger": 1, |
65 | 65 | "jaccard": 1, |
| 66 | + "bit_jaccard": 1, |
66 | 67 | "dice": 1, |
67 | 68 | } |
68 | 69 |
|
@@ -2351,8 +2352,10 @@ def fit(self, X, y=None, force_all_finite=True): |
2351 | 2352 | - 'allow-nan': accepts only np.nan and pd.NA values in array. |
2352 | 2353 | Values cannot be infinite. |
2353 | 2354 | """ |
2354 | | - |
2355 | | - X = check_array(X, dtype=np.float32, accept_sparse="csr", order="C", force_all_finite=force_all_finite) |
| 2355 | + if self.metric in ("bit_hamming", "bit_jaccard"): |
| 2356 | + X = check_array(X, dtype=np.uint8, order="C", force_all_finite=force_all_finite) |
| 2357 | + else: |
| 2358 | + X = check_array(X, dtype=np.float32, accept_sparse="csr", order="C", force_all_finite=force_all_finite) |
2356 | 2359 | self._raw_data = X |
2357 | 2360 |
|
2358 | 2361 | # Handle all the optional arguments, setting default |
@@ -2926,7 +2929,10 @@ def transform(self, X, force_all_finite=True): |
2926 | 2929 | "Transform unavailable when model was fit with only a single data sample." |
2927 | 2930 | ) |
2928 | 2931 | # If we just have the original input then short circuit things |
2929 | | - X = check_array(X, dtype=np.float32, accept_sparse="csr", order="C", force_all_finite=force_all_finite) |
| 2932 | + if self.metric in ("bit_hamming", "bit_jaccard"): |
| 2933 | + X = check_array(X, dtype=np.uint8, order="C", force_all_finite=force_all_finite) |
| 2934 | + else: |
| 2935 | + X = check_array(X, dtype=np.float32, accept_sparse="csr", order="C", force_all_finite=force_all_finite) |
2930 | 2936 | x_hash = joblib.hash(X) |
2931 | 2937 | if x_hash == self._input_hash: |
2932 | 2938 | if self.transform_mode == "embedding": |
@@ -3297,7 +3303,10 @@ def _output_dist_only(x, y, *kwds): |
3297 | 3303 | return inv_transformed_points |
3298 | 3304 |
|
3299 | 3305 | def update(self, X, force_all_finite=True): |
3300 | | - X = check_array(X, dtype=np.float32, accept_sparse="csr", order="C", force_all_finite=force_all_finite) |
| 3306 | + if self.metric in ("bit_hamming", "bit_jaccard"): |
| 3307 | + X = check_array(X, dtype=np.uint8, order="C", force_all_finite=force_all_finite) |
| 3308 | + else: |
| 3309 | + X = check_array(X, dtype=np.float32, accept_sparse="csr", order="C", force_all_finite=force_all_finite) |
3301 | 3310 | random_state = check_random_state(self.transform_seed) |
3302 | 3311 | rng_state = random_state.randint(INT32_MIN, INT32_MAX, 3).astype(np.int64) |
3303 | 3312 |
|
|
0 commit comments