-
-
Notifications
You must be signed in to change notification settings - Fork 48.7k
Add Quantum k-Means Clustering Implementation #11664
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
7d1d891
027549b
e2d0e50
89f5f80
c46141d
caf89c4
facfce2
3655742
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
import cirq | ||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
from sklearn.datasets import make_blobs | ||
from sklearn.preprocessing import MinMaxScaler | ||
|
||
def generate_data(n_samples=100, n_features=2, n_clusters=2): | ||
RahulPatnaik marked this conversation as resolved.
Show resolved
Hide resolved
|
||
data, labels = make_blobs(n_samples=n_samples, centers=n_clusters, n_features=n_features, random_state=42) | ||
return MinMaxScaler().fit_transform(data), labels | ||
|
||
def quantum_distance(point1, point2): | ||
RahulPatnaik marked this conversation as resolved.
Show resolved
Hide resolved
RahulPatnaik marked this conversation as resolved.
Show resolved
Hide resolved
|
||
""" | ||
Quantum circuit explanation: | ||
1. Use a single qubit to encode the distance between two points. | ||
2. Apply Ry rotation based on the normalized Euclidean distance. | ||
3. Measure the qubit to get a probabilistic distance metric. | ||
The probability of measuring |1> correlates with the distance between points. | ||
""" | ||
qubit = cirq.LineQubit(0) | ||
diff = np.clip(np.linalg.norm(point1 - point2), 0, 1) | ||
theta = 2 * np.arcsin(diff) | ||
|
||
circuit = cirq.Circuit( | ||
cirq.ry(theta)(qubit), | ||
cirq.measure(qubit, key='result') | ||
) | ||
|
||
result = cirq.Simulator().run(circuit, repetitions=1000) | ||
return result.histogram(key='result').get(1, 0) / 1000 | ||
|
||
def initialize_centroids(data, k): | ||
|
||
return data[np.random.choice(len(data), k, replace=False)] | ||
|
||
def assign_clusters(data, centroids): | ||
RahulPatnaik marked this conversation as resolved.
Show resolved
Hide resolved
RahulPatnaik marked this conversation as resolved.
Show resolved
Hide resolved
|
||
clusters = [[] for _ in range(len(centroids))] | ||
for point in data: | ||
closest = min(range(len(centroids)), key=lambda i: quantum_distance(point, centroids[i])) | ||
|
||
clusters[closest].append(point) | ||
return clusters | ||
|
||
def recompute_centroids(clusters): | ||
RahulPatnaik marked this conversation as resolved.
Show resolved
Hide resolved
RahulPatnaik marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return np.array([np.mean(cluster, axis=0) for cluster in clusters if cluster]) | ||
|
||
def quantum_kmeans(data, k, max_iters=10): | ||
RahulPatnaik marked this conversation as resolved.
Show resolved
Hide resolved
RahulPatnaik marked this conversation as resolved.
Show resolved
Hide resolved
|
||
centroids = initialize_centroids(data, k) | ||
|
||
for _ in range(max_iters): | ||
clusters = assign_clusters(data, centroids) | ||
new_centroids = recompute_centroids(clusters) | ||
if np.allclose(new_centroids, centroids): | ||
break | ||
centroids = new_centroids | ||
|
||
return centroids, clusters | ||
|
||
# Main execution | ||
n_samples, n_clusters = 10, 2 | ||
data, labels = generate_data(n_samples, n_clusters=n_clusters) | ||
|
||
plt.figure(figsize=(12, 5)) | ||
|
||
plt.subplot(121) | ||
plt.scatter(data[:, 0], data[:, 1], c=labels) | ||
plt.title("Generated Data") | ||
|
||
final_centroids, final_clusters = quantum_kmeans(data, n_clusters) | ||
|
||
plt.subplot(122) | ||
for i, cluster in enumerate(final_clusters): | ||
cluster = np.array(cluster) | ||
plt.scatter(cluster[:, 0], cluster[:, 1], label=f'Cluster {i+1}') | ||
plt.scatter(final_centroids[:, 0], final_centroids[:, 1], color='red', marker='x', s=200, linewidths=3, label='Centroids') | ||
plt.title("Quantum k-Means Clustering with Cirq") | ||
plt.legend() | ||
|
||
plt.tight_layout() | ||
plt.show() | ||
|
||
print(f"Final Centroids:\n{final_centroids}") | ||
Uh oh!
There was an error while loading. Please reload this page.