Skip to content

Commit 32a82e7

Browse files
committed
完成20-26
1 parent 4c1f935 commit 32a82e7

File tree

7 files changed

+148
-61
lines changed

7 files changed

+148
-61
lines changed

exercises/20_function_template/main.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
#include "../exercise.h"
22

33
// READ: 函数模板 <https://zh.cppreference.com/w/cpp/language/function_template>
4-
// 函数模板定义
5-
// template<typename T>
6-
// T plus(T a, T b) {
7-
// return a + b;
8-
// }
4+
// 函数模板--`template<typename T>`,
5+
// - 函数模板就像是一个“模具”或“蓝图”。
6+
// - 不需要为 int、float 或 double 分别编写逻辑相同的函数
7+
// - // T1 和 T2 可以是不同类型`template <typename T1, typename T2>`
98
// TODO: 将这个函数模板化
109
// int plus(int a, int b) {
1110
// return a + b;

exercises/21_runtime_datatype/main.cpp

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,27 @@ struct TaggedUnion {
1818
};
1919

2020
// TODO: 将这个函数模板化用于 sigmoid_dyn
21-
float sigmoid(float x) {
21+
// float sigmoid(float x) {
22+
// return 1 / (1 + std::exp(-x));
23+
// }
24+
template<typename T>
25+
T sigmoid(T x){
2226
return 1 / (1 + std::exp(-x));
2327
}
2428

2529
TaggedUnion sigmoid_dyn(TaggedUnion x) {
2630
TaggedUnion ans{x.type};
2731
// TODO: 根据 type 调用 sigmoid
32+
// 让函数能同时处理 float 和 double
33+
switch (x.type)
34+
{
35+
case DataType::Float:
36+
ans.f = sigmoid(x.f);
37+
break;
38+
case DataType::Double:
39+
ans.d = sigmoid(x.d);
40+
break;
41+
}
2842
return ans;
2943
}
3044

exercises/22_class_template/main.cpp

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#include "../exercise.h"
22
#include <cstring>
33
// READ: 类模板 <https://zh.cppreference.com/w/cpp/language/class_template>
4-
4+
// 类模板(Class Template) 是泛型编程的核心工具。它允许我们编写一个通用的类,而不必为每种数据类型都重写一遍代码。
55
template<class T>
66
struct Tensor4D {
77
unsigned int shape[4];
@@ -10,6 +10,10 @@ struct Tensor4D {
1010
Tensor4D(unsigned int const shape_[4], T const *data_) {
1111
unsigned int size = 1;
1212
// TODO: 填入正确的 shape 并计算 size
13+
for (int i = 0; i < 4; i++) {
14+
shape[i] = shape_[i];
15+
size *= shape[i];
16+
}
1317
data = new T[size];
1418
std::memcpy(data, data_, size * sizeof(T));
1519
}
@@ -28,10 +32,47 @@ struct Tensor4D {
2832
// 则 `this` 与 `others` 相加时,3 个形状为 `[1, 2, 1, 4]` 的子张量各自与 `others` 对应项相加。
2933
Tensor4D &operator+=(Tensor4D const &others) {
3034
// TODO: 实现单向广播的加法
35+
// 1.预先计算others 的 strides (步长),用于将 (n,c,h,w) 坐标转换为线性索引
36+
// 线性索引 = n * stride_0 + c * stride_1 + h * stride_2 + w * stride_3
37+
unsigned int o_stride3=1;
38+
unsigned int o_stride2=others.shape[3];
39+
unsigned int o_stride1=o_stride2*others.shape[2];
40+
unsigned int o_stride0=o_stride1*others.shape[1];
41+
//// 使用指针直接遍历 this->data,避免重复计算 this 的线性索引
42+
T* current_ptr = this->data;
43+
// 4层循环遍历 this 的所有维度
44+
for (unsigned int n = 0; n < shape[0]; ++n) {
45+
// 如果 others 在该维度长度为1,则索引固定为0(广播),否则跟随 n
46+
unsigned int n_idx = (others.shape[0] == 1) ? 0 : n;
47+
48+
for (unsigned int c = 0; c < shape[1]; ++c) {
49+
unsigned int c_idx = (others.shape[1] == 1) ? 0 : c;
50+
51+
for (unsigned int h = 0; h < shape[2]; ++h) {
52+
unsigned int h_idx = (others.shape[2] == 1) ? 0 : h;
53+
54+
for (unsigned int w = 0; w < shape[3]; ++w) {
55+
unsigned int w_idx = (others.shape[3] == 1) ? 0 : w;
56+
// 计算 others 中的线性偏移量
57+
unsigned int others_offset =
58+
n_idx * o_stride0 +
59+
c_idx * o_stride1 +
60+
h_idx * o_stride2 +
61+
w_idx * o_stride3;
62+
// 执行加法
63+
*current_ptr += others.data[others_offset];
64+
65+
// 移动到 this 的下一个元素
66+
++current_ptr;
67+
}
68+
}
69+
}
70+
}
3171
return *this;
3272
}
3373
};
3474

75+
3576
// ---- 不要修改以下代码 ----
3677
int main(int argc, char **argv) {
3778
{

exercises/23_template_const/main.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,11 @@ struct Tensor {
1111
Tensor(unsigned int const shape_[N]) {
1212
unsigned int size = 1;
1313
// TODO: 填入正确的 shape 并计算 size
14+
for (unsigned int i = 0; i < N; i++)
15+
{
16+
shape[i] = shape_[i];
17+
size *= shape[i];
18+
}
1419
data = new T[size];
1520
std::memset(data, 0, size * sizeof(T));
1621
}
@@ -35,6 +40,7 @@ struct Tensor {
3540
for (unsigned int i = 0; i < N; ++i) {
3641
ASSERT(indices[i] < shape[i], "Invalid index");
3742
// TODO: 计算 index
43+
index = index * shape[i] + indices[i];
3844
}
3945
return index;
4046
}

exercises/24_std_array/main.cpp

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,25 +4,43 @@
44

55
// READ: std::array <https://zh.cppreference.com/w/cpp/container/array>
66

7+
// std::array是对 C 语言风格数组(如 int arr[5])的现代化封装。
8+
// - 既保留了原生数组的高性能(零开销)
9+
// - 又提供了类似 STL 容器(如 std::vector)的安全性和便利性接口。
10+
// 1.基本语法,需要包含头文件 <array>。
11+
// - std::array<int, 5> myArray = {1, 2, 3, 4, 5};
12+
// - 第一个模板参数 int 是数据类型,第二个模板参数 5 是数组大小(必须是编译期常量)。
13+
// 2.为什么要使用 std::array?
14+
// - std::array 是一个对象,包含长度消息和数据,不会退化为指针
15+
// - std::array 支持直接赋值,只要类型和大小相同
16+
// - std::array 安全性高,提供了 .at() 方法,会进行边界检查。如果越界,会抛出异常
17+
// - std::array STL兼容性强,提供了迭代器(begin(), end()),可以完美配合 std::sort, std::for_each 等标准算法使用。
18+
// 3. 主要特性
19+
// - 内存分配: 与 C 数组一样,数据通常存储在栈(Stack)上(除非它是全局变量或被 new 出来的)。这意味着没有动态内存分配(堆内存)的开销,速度极快。
20+
// - 固定大小: 大小在编译时必须确定,运行时不能改变(不能像 vector 那样 push_back)。
21+
// - 零开销: 在优化开启的情况下,std::array 的性能与 C 风格数组完全一致。
22+
23+
// std::memcmp (Memory Compare) 是用来比较两块内存区域的内容是否完全一致的函数
724
// TODO: 将下列 `?` 替换为正确的代码
825
int main(int argc, char **argv) {
926
{
1027
std::array<int, 5> arr{{1, 2, 3, 4, 5}};
11-
ASSERT(arr.size() == ?, "Fill in the correct value.");
12-
ASSERT(sizeof(arr) == ?, "Fill in the correct value.");
28+
ASSERT(arr.size() == 5, "Fill in the correct value.");//数组大小
29+
ASSERT(sizeof(arr) == 20, "Fill in the correct value.");//int 的大小为 4
1330
int ans[]{1, 2, 3, 4, 5};
14-
ASSERT(std::memcmp(arr.?, ans, ?) == 0, "Fill in the correct values.");
31+
//比较arr.data 和 ans,长度为 sizeof(ans)
32+
ASSERT(std::memcmp(arr.data(), ans, sizeof(ans)) == 0, "Fill in the correct values.");
1533
}
1634
{
1735
std::array<double, 8> arr;
18-
ASSERT(arr.size() == ?, "Fill in the correct value.");
19-
ASSERT(sizeof(arr) == ?, "Fill in the correct value.");
36+
ASSERT(arr.size() == 8, "Fill in the correct value.");
37+
ASSERT(sizeof(arr) == 64, "Fill in the correct value.");//double 的大小为 8
2038
}
2139
{
2240
std::array<char, 21> arr{"Hello, InfiniTensor!"};
23-
ASSERT(arr.size() == ?, "Fill in the correct value.");
24-
ASSERT(sizeof(arr) == ?, "Fill in the correct value.");
25-
ASSERT(std::strcmp(arr.?, "Hello, InfiniTensor!") == 0, "Fill in the correct value.");
41+
ASSERT(arr.size() == 21, "Fill in the correct value.");
42+
ASSERT(sizeof(arr) == 21, "Fill in the correct value.");//char 的大小为 1
43+
ASSERT(std::strcmp(arr.data(), "Hello, InfiniTensor!") == 0, "Fill in the correct value.");
2644
}
2745
return 0;
2846
}

exercises/25_std_vector/main.cpp

Lines changed: 47 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -3,86 +3,95 @@
33
#include <vector>
44

55
// READ: std::vector <https://zh.cppreference.com/w/cpp/container/vector>
6-
6+
// vector 是一种序列容器,它允许你在运行时动态地插入和删除元素。需要使用头文件 <vector>。
77
// TODO: 将下列 `?` 替换为正确的代码
88
int main(int argc, char **argv) {
99
{
1010
std::vector<int> vec{1, 2, 3, 4, 5};
11-
ASSERT(vec.size() == ?, "Fill in the correct value.");
11+
ASSERT(vec.size() == 5, "Fill in the correct value.");
1212
// THINK: `std::vector` 的大小是什么意思?与什么有关?
13-
ASSERT(sizeof(vec) == ?, "Fill in the correct value.");
13+
// // 在 64 位系统上,sizeof(vector) 通常是 24 (3个指针)
14+
ASSERT(sizeof(vec) == 24, "Fill in the correct value.");
1415
int ans[]{1, 2, 3, 4, 5};
15-
ASSERT(std::memcmp(vec.?, ans, sizeof(ans)) == 0, "Fill in the correct values.");
16+
ASSERT(std::memcmp(vec.data(), ans, sizeof(ans)) == 0, "Fill in the correct values.");
1617
}
1718
{
1819
std::vector<double> vec{1, 2, 3, 4, 5};
1920
{
20-
ASSERT(vec.size() == ?, "Fill in the correct value.");
21-
ASSERT(sizeof(vec) == ?, "Fill in the correct value.");
21+
ASSERT(vec.size() == 5, "Fill in the correct value.");
22+
ASSERT(sizeof(vec) == 24, "Fill in the correct value.");
2223
double ans[]{1, 2, 3, 4, 5};
23-
ASSERT(std::memcmp(vec.?, ans, sizeof(ans)) == 0, "Fill in the correct values.");
24+
ASSERT(std::memcmp(vec.data(), ans, sizeof(ans)) == 0, "Fill in the correct values.");
2425
}
2526
{
26-
vec.push_back(6);
27-
ASSERT(vec.size() == ?, "Fill in the correct value.");
28-
ASSERT(sizeof(vec) == ?, "Fill in the correct value.");
29-
vec.pop_back();
30-
ASSERT(vec.size() == ?, "Fill in the correct value.");
31-
ASSERT(sizeof(vec) == ?, "Fill in the correct value.");
27+
vec.push_back(6);// 入栈
28+
ASSERT(vec.size() == 6, "Fill in the correct value.");
29+
ASSERT(sizeof(vec) == 24, "Fill in the correct value.");
30+
vec.pop_back();// 出栈
31+
ASSERT(vec.size() == 5, "Fill in the correct value.");
32+
ASSERT(sizeof(vec) == 24, "Fill in the correct value.");
3233
}
3334
{
3435
vec[4] = 6;
35-
ASSERT(vec[0] == ?, "Fill in the correct value.");
36-
ASSERT(vec[1] == ?, "Fill in the correct value.");
37-
ASSERT(vec[2] == ?, "Fill in the correct value.");
38-
ASSERT(vec[3] == ?, "Fill in the correct value.");
39-
ASSERT(vec[4] == ?, "Fill in the correct value.");
36+
ASSERT(vec[0] == 1, "Fill in the correct value.");
37+
ASSERT(vec[1] == 2, "Fill in the correct value.");
38+
ASSERT(vec[2] == 3, "Fill in the correct value.");
39+
ASSERT(vec[3] == 4, "Fill in the correct value.");
40+
ASSERT(vec[4] == 6, "Fill in the correct value.");
4041
}
4142
{
4243
// THINK: `std::vector` 插入删除的时间复杂度是什么?
43-
vec.insert(?, 1.5);
44+
// insert 插入,在中间或头部插入会导致插入点之后的元素全部向后移动,时间复杂度为 $O(n)$。
45+
// - insert 的第一个参数必须是“迭代器”(Iterator),不能是下标。
46+
// - 插入单个元素 vec.insert(it+index, element)。;
47+
// - 插入多个相同的元素 vec.insert(it+index, count, element)。;
48+
49+
// auto it=vec.begin();
50+
// vec.insert(it+1, 1.5);
51+
vec.insert(vec.begin() + 1, 1.5);
4452
ASSERT((vec == std::vector<double>{1, 1.5, 2, 3, 4, 6}), "Make this assertion pass.");
45-
vec.erase(?);
53+
vec.erase(vec.begin() + 3);
4654
ASSERT((vec == std::vector<double>{1, 1.5, 2, 4, 6}), "Make this assertion pass.");
4755
}
4856
{
4957
vec.shrink_to_fit();
50-
ASSERT(vec.capacity() == ?, "Fill in the correct value.");
58+
ASSERT(vec.capacity() == 5, "Fill in the correct value.");
5159
vec.clear();
52-
ASSERT(vec.empty(), "`vec` is empty now.");
53-
ASSERT(vec.size() == ?, "Fill in the correct value.");
54-
ASSERT(vec.capacity() == ?, "Fill in the correct value.");
60+
ASSERT(vec.empty(), "`vec` is empty now.");// empty不改变capacity
61+
ASSERT(vec.size() == 0, "Fill in the correct value.");
62+
ASSERT(vec.capacity() == 5, "Fill in the correct value.");
5563
}
5664
}
5765
{
58-
std::vector<char> vec(?, ?); // TODO: 调用正确的构造函数
66+
std::vector<char> vec(48, 'z'); // TODO: 调用正确的构造函数
5967
ASSERT(vec[0] == 'z', "Make this assertion pass.");
6068
ASSERT(vec[47] == 'z', "Make this assertion pass.");
6169
ASSERT(vec.size() == 48, "Make this assertion pass.");
62-
ASSERT(sizeof(vec) == ?, "Fill in the correct value.");
70+
ASSERT(sizeof(vec) == 24, "Fill in the correct value.");
6371
{
64-
auto capacity = vec.capacity();
72+
// resize改变大小(size)和容量(capacity)
6573
vec.resize(16);
66-
ASSERT(vec.size() == ?, "Fill in the correct value.");
67-
ASSERT(vec.capacity() == ?, "Fill in a correct identifier.");
74+
ASSERT(vec.size() == 16, "Fill in the correct value.");
75+
ASSERT(vec.capacity() == 48, "Fill in a correct identifier.");
6876
}
6977
{
78+
//reserve预留空间,只改变capacity
7079
vec.reserve(256);
71-
ASSERT(vec.size() == ?, "Fill in the correct value.");
72-
ASSERT(vec.capacity() == ?, "Fill in the correct value.");
80+
ASSERT(vec.size() == 16, "Fill in the correct value.");
81+
ASSERT(vec.capacity() == 256, "Fill in the correct value.");
7382
}
7483
{
7584
vec.push_back('a');
7685
vec.push_back('b');
7786
vec.push_back('c');
7887
vec.push_back('d');
79-
ASSERT(vec.size() == ?, "Fill in the correct value.");
80-
ASSERT(vec.capacity() == ?, "Fill in the correct value.");
81-
ASSERT(vec[15] == ?, "Fill in the correct value.");
82-
ASSERT(vec[?] == 'a', "Fill in the correct value.");
83-
ASSERT(vec[?] == 'b', "Fill in the correct value.");
84-
ASSERT(vec[?] == 'c', "Fill in the correct value.");
85-
ASSERT(vec[?] == 'd', "Fill in the correct value.");
88+
ASSERT(vec.size() == 20, "Fill in the correct value.");
89+
ASSERT(vec.capacity() == 256, "Fill in the correct value.");
90+
ASSERT(vec[15] == 'z', "Fill in the correct value.");
91+
ASSERT(vec[16] == 'a', "Fill in the correct value.");
92+
ASSERT(vec[17] == 'b', "Fill in the correct value.");
93+
ASSERT(vec[18] == 'c', "Fill in the correct value.");
94+
ASSERT(vec[19] == 'd', "Fill in the correct value.");
8695
}
8796
}
8897

exercises/26_std_vector_bool/main.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,29 +6,29 @@
66

77
// TODO: 将下列 `?` 替换为正确的代码
88
int main(int argc, char **argv) {
9-
std::vector<bool> vec(?, ?);// TODO: 正确调用构造函数
9+
std::vector<bool> vec(100, true);// TODO: 正确调用构造函数
1010
ASSERT(vec[0], "Make this assertion pass.");
1111
ASSERT(vec[99], "Make this assertion pass.");
1212
ASSERT(vec.size() == 100, "Make this assertion pass.");
1313
// NOTICE: 平台相关!注意 CI:Ubuntu 上的值。
1414
std::cout << "sizeof(std::vector<bool>) = " << sizeof(std::vector<bool>) << std::endl;
15-
ASSERT(sizeof(vec) == ?, "Fill in the correct value.");
15+
ASSERT(sizeof(vec) == 40, "Fill in the correct value.");
1616
{
1717
vec[20] = false;
18-
ASSERT(?vec[20], "Fill in `vec[20]` or `!vec[20]`.");
18+
ASSERT(!vec[20], "Fill in `vec[20]` or `!vec[20]`.");
1919
}
2020
{
2121
vec.push_back(false);
22-
ASSERT(vec.size() == ?, "Fill in the correct value.");
23-
ASSERT(?vec[100], "Fill in `vec[100]` or `!vec[100]`.");
22+
ASSERT(vec.size() == 101, "Fill in the correct value.");
23+
ASSERT(!vec[100], "Fill in `vec[100]` or `!vec[100]`.");
2424
}
2525
{
2626
auto ref = vec[30];
27-
ASSERT(?ref, "Fill in `ref` or `!ref`");
27+
ASSERT(ref, "Fill in `ref` or `!ref`");
2828
ref = false;
29-
ASSERT(?ref, "Fill in `ref` or `!ref`");
29+
ASSERT(!ref, "Fill in `ref` or `!ref`");
3030
// THINK: WHAT and WHY?
31-
ASSERT(?vec[30], "Fill in `vec[30]` or `!vec[30]`.");
31+
ASSERT(!vec[30], "Fill in `vec[30]` or `!vec[30]`.");
3232
}
3333
return 0;
3434
}

0 commit comments

Comments
 (0)