Skip to content

Commit 28ad3de

Browse files
📝 Update K-Means clustering documentation with interactive animation and enhanced PyTorch section
- Replaced static Python example with an interactive K-Means animation using React for better visualization of the clustering process. - Expanded the PyTorch section to include detailed explanations of CUDA, cuDNN, and the CUDA Toolkit, providing essential information for users on GPU acceleration. - Improved overall structure and clarity of the documentation to facilitate understanding of K-Means and its implementation in PyTorch.
1 parent a71b1c6 commit 28ad3de

File tree

2 files changed

+182
-27
lines changed

2 files changed

+182
-27
lines changed

docs/docs/机器学习/传统算法/K均值算法.md

Lines changed: 155 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -46,31 +46,162 @@ $
4646

4747
以此类推
4848

49-
```python showLineNumbers
50-
from sklearn.cluster import KMeans
import numpy as np

# Build a small 2-D sample data set (one row per sample).
X = np.array([[1, 2], [2, 3], [2, 5], [3, 2], [3, 3], [4, 5]])

# Create the K-Means model.
k = 2  # number of clusters to split the data into
# n_init and random_state are pinned so that every run prints the same
# centres and labels (otherwise the cluster numbering may differ between
# runs), and so that scikit-learn versions that are changing the n_init
# default do not emit a FutureWarning.
model = KMeans(n_clusters=k, n_init=10, random_state=0)

# Fit the model.
# .fit() trains the model, i.e. lets it learn the clusters from the data.
model.fit(X)

# Coordinates of the learned cluster centres.
cluster_centers = model.cluster_centers_

# Cluster index assigned to each training sample during fitting.
labels = model.labels_

print("簇中心点:", cluster_centers)
print("样本所属簇:", labels)
72-
49+
<details>
50+
<summary>点击查看动画</summary>
51+
``` jsx live
52+
function KMeansAnimation() {
53+
const gridSize = 10;
54+
55+
const [dataPoints, setDataPoints] = React.useState([]);
56+
const [centroids, setCentroids] = React.useState([
57+
{ x: 0, y: 0 },
58+
{ x: 5, y: 5 }
59+
]);
60+
const [step, setStep] = React.useState(0);
61+
const [iteration, setIteration] = React.useState(0);
62+
63+
React.useEffect(() => {
64+
const generateAllGridPoints = () => {
65+
const points = [];
66+
for (let i = 0; i < gridSize; i++) {
67+
for (let j = 0; j < gridSize; j++) {
68+
points.push({
69+
x: i,
70+
y: j,
71+
cluster: null
72+
});
73+
}
74+
}
75+
return points;
76+
};
77+
78+
setDataPoints(generateAllGridPoints());
79+
}, []);
80+
81+
const distance = (point1, point2) => {
82+
return Math.sqrt(Math.pow(point1.x - point2.x, 2) + Math.pow(point1.y - point2.y, 2));
83+
};
84+
85+
React.useEffect(() => {
86+
if (dataPoints.length === 0) return;
87+
88+
const timer = setTimeout(() => {
89+
if (step === 0) {
90+
const newDataPoints = dataPoints.map(point => {
91+
const dist1 = distance(point, centroids[0]);
92+
const dist2 = distance(point, centroids[1]);
93+
return {
94+
...point,
95+
cluster: dist1 <= dist2 ? 0 : 1
96+
};
97+
});
98+
setDataPoints(newDataPoints);
99+
setStep(1);
100+
} else if (step === 1) {
101+
const cluster0Points = dataPoints.filter(p => p.cluster === 0);
102+
const cluster1Points = dataPoints.filter(p => p.cluster === 1);
103+
104+
if (cluster0Points.length > 0 && cluster1Points.length > 0) {
105+
const newX0 = Math.round(cluster0Points.reduce((sum, p) => sum + p.x, 0) / cluster0Points.length);
106+
const newY0 = Math.round(cluster0Points.reduce((sum, p) => sum + p.y, 0) / cluster0Points.length);
107+
108+
const newX1 = Math.round(cluster1Points.reduce((sum, p) => sum + p.x, 0) / cluster1Points.length);
109+
const newY1 = Math.round(cluster1Points.reduce((sum, p) => sum + p.y, 0) / cluster1Points.length);
110+
111+
setCentroids([
112+
{ x: newX0, y: newY0 },
113+
{ x: newX1, y: newY1 }
114+
]);
115+
}
116+
117+
setStep(0);
118+
setIteration(prev => prev + 1);
119+
}
120+
}, 1000);
121+
122+
return () => clearTimeout(timer);
123+
}, [step, dataPoints, centroids]);
124+
125+
const renderGrid = () => {
126+
const grid = [];
127+
128+
for (let y = 0; y < gridSize; y++) {
129+
for (let x = 0; x < gridSize; x++) {
130+
const pointAtPosition = dataPoints.find(p => p.x === x && p.y === y);
131+
132+
const isCentroid0 = centroids[0].x === x && centroids[0].y === y;
133+
const isCentroid1 = centroids[1].x === x && centroids[1].y === y;
134+
135+
let cellStyle = {
136+
width: '32px',
137+
height: '32px',
138+
border: '1px solid #cbd5e0',
139+
display: 'flex',
140+
alignItems: 'center',
141+
justifyContent: 'center'
142+
};
143+
144+
if (pointAtPosition) {
145+
if (pointAtPosition.cluster === 0) {
146+
cellStyle.backgroundColor = '#9ae6b4';
147+
} else if (pointAtPosition.cluster === 1) {
148+
cellStyle.backgroundColor = '#90cdf4';
149+
}
150+
}
151+
152+
if (isCentroid0) {
153+
cellStyle.backgroundColor = '#276749';
154+
} else if (isCentroid1) {
155+
cellStyle.backgroundColor = '#2b6cb0';
156+
}
157+
158+
grid.push(
159+
<div key={`${x}-${y}`} style={cellStyle}></div>
160+
);
161+
}
162+
}
163+
164+
return grid;
165+
};
166+
167+
return (
168+
<div style={{display: 'flex', flexDirection: 'column', alignItems: 'center', padding: '16px'}}>
169+
<h2 style={{fontSize: '1.25rem', fontWeight: 'bold', marginBottom: '16px'}}>K-Means 聚类算法可视化</h2>
170+
<div style={{marginBottom: '16px'}}>迭代次数: {iteration}</div>
171+
<div style={{
172+
display: 'grid',
173+
gridTemplateColumns: 'repeat(10, 1fr)',
174+
gap: '4px',
175+
marginBottom: '16px'
176+
}}>
177+
{renderGrid()}
178+
</div>
179+
<div style={{marginTop: '16px', display: 'flex', gap: '24px'}}>
180+
<div style={{display: 'flex', alignItems: 'center'}}>
181+
<div style={{width: '16px', height: '16px', backgroundColor: '#9ae6b4', marginRight: '8px'}}></div>
182+
<span>1数据点</span>
183+
</div>
184+
<div style={{display: 'flex', alignItems: 'center'}}>
185+
<div style={{width: '16px', height: '16px', backgroundColor: '#276749', marginRight: '8px'}}></div>
186+
<span>1中心点</span>
187+
</div>
188+
<div style={{display: 'flex', alignItems: 'center'}}>
189+
<div style={{width: '16px', height: '16px', backgroundColor: '#90cdf4', marginRight: '8px'}}></div>
190+
<span>2数据点</span>
191+
</div>
192+
<div style={{display: 'flex', alignItems: 'center'}}>
193+
<div style={{width: '16px', height: '16px', backgroundColor: '#2b6cb0', marginRight: '8px'}}></div>
194+
<span>2中心点</span>
195+
</div>
196+
</div>
197+
</div>
198+
);
199+
}
200+
201+
export default KMeansAnimation;
73202
```
203+
</details>
204+
74205

75206
### 简单示例
76207

docs/docs/机器学习/神经网络/index.md

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,7 @@ sidebar_position: 3
33
title: 🚧神经网络入门
44
---
55

6-
## PyTorch
7-
8-
### skorch
6+
## skorch
97

108
至此,我们已经学习了 scikit-learn 的很多算法。完成了传统机器学习的任务。
119

@@ -118,6 +116,32 @@ for i in range(3):
118116
接下来,你可以使用 skorch 来复现更多之前的项目,同时熟悉 PyTorch 的用法。之后我们会开始使用 PyTorch 来完成更多复杂而有趣的任务。
119117

120118

119+
## PyTorch
120+
121+
PyTorch 可以利用计算加速设备(例如 GPU、NPU)。以 NVIDIA GPU 为例,PyTorch 的安装包会绑定对应的 CUDA 版本,并通过 CUDA 的接口来操作底层硬件。
122+
123+
:::info
124+
125+
**CUDA**:NVIDIA 专为自家 GPU 设计的并行计算平台和编程模型(主要以 C/C++ 语言扩展的形式提供),其运行依赖于 NVIDIA 显卡驱动程序。它允许开发者利用 GPU 强大的并行计算能力加速各类计算密集型任务。
126+
127+
**cuDNN**:专门为深度学习计算优化的高性能 GPU 加速库,为卷积、池化、归一化等常见深度学习操作提供了高度优化的实现。
128+
129+
**CUDA Toolkit (NVIDIA 官方版)**:完整的 CUDA 开发环境,包含:
130+
- NVIDIA 显卡驱动程序
131+
- 完整的 CUDA 开发工具链(编译器、IDE、调试器等)
132+
- 各种 CUDA 加速库及其头文件
133+
- 文档和示例代码
134+
135+
**CUDA Toolkit (PyTorch 版)**:精简版 CUDA 工具包,主要包含:
136+
- 运行 CUDA 功能所需的核心动态链接库
137+
- 不包含驱动程序、开发工具及完整文档
138+
- 专为支持 PyTorch 等框架的 CUDA 功能而设计
139+
140+
**NVCC**:NVIDIA CUDA 编译器,是 CUDA Toolkit 的核心组件,负责将 CUDA 代码编译为可在 NVIDIA GPU 上执行的二进制代码。
141+
:::
142+
143+
144+
121145
### PyTorch 数据集
122146

123147
PyTorch 也提供了一些内置的数据集类来加载常用的数据集,如图像、文本等。此外,你也可以使用第三方库来加载自定义的数据集。

0 commit comments

Comments
 (0)