Skip to content

Commit 6cf16d9

Browse files
committed
提交了项目代码
1 parent f2acc1d commit 6cf16d9

File tree

3 files changed

+106
-4
lines changed

3 files changed

+106
-4
lines changed

examples/CNN_UTS/main.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import os
22

33
import hydra
4-
import matplotlib.pyplot as plt
54
import numpy as np
65
import paddle
76
import tqdm
@@ -193,6 +192,9 @@ def evaluate(cfg):
193192
"""
194193
完整的评估流程,包括模型加载、推理、指标计算、可视化和集成预测
195194
"""
195+
import matplotlib.pyplot as plt
196+
from sklearn.model_selection import StratifiedGroupKFold
197+
196198
set_seed(cfg.seed)
197199
device = device2str(cfg.device)
198200
n_splits = cfg.train.n_splits

examples/CNN_UTS/readme.md

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
├── Saved_Output/ # 保存预测结果与可视化
1515
├── requirements.txt # 依赖库
1616
└── readme.md # 项目说明文档
17-
17+
```
1818
1919
## 环境依赖
2020
@@ -83,8 +83,15 @@ python main.py mode=eval
8383
- 配置化超参数与数据路径
8484
- 自动保存/加载模型与预测结果
8585
- 多种可视化与统计指标输出
86+
如需引用本项目复现的方法,请参考原论文:
8687

87-
## 联系方式
88+
plaintext
89+
@article{lai2025predicting,
90+
title={Predicting the Strength of Composites with Computer Vision Using Small Experimental Datasets},
91+
author={Lai, Po-Hao and Gomez, Enrique D and Vogt, Bryan D and Reinhart, Wesley F},
92+
journal={ACS Materials Letters},
93+
year={2025},
94+
publisher={American Chemical Society}
95+
}
8896

89-
如有问题请联系:[email protected]
9097

ppsci/arch/resnet.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
import paddle
2+
import paddle.nn as nn
3+
4+
from ppsci.arch import base
5+
6+
7+
class ResNetBlock(nn.Layer):
8+
def __init__(self, in_channels, out_channels, stride=1):
9+
super().__init__()
10+
self.conv1 = nn.Conv2D(in_channels, out_channels, 3, stride, 1)
11+
self.bn1 = nn.BatchNorm2D(out_channels)
12+
self.relu = nn.ReLU()
13+
self.conv2 = nn.Conv2D(out_channels, out_channels, 3, 1, 1)
14+
self.bn2 = nn.BatchNorm2D(out_channels)
15+
if stride != 1 or in_channels != out_channels:
16+
self.downsample = nn.Sequential(
17+
nn.Conv2D(in_channels, out_channels, 1, stride),
18+
nn.BatchNorm2D(out_channels),
19+
)
20+
else:
21+
self.downsample = None
22+
23+
def forward(self, x):
24+
identity = x
25+
out = self.relu(self.bn1(self.conv1(x)))
26+
out = self.bn2(self.conv2(out))
27+
if self.downsample is not None:
28+
identity = self.downsample(x)
29+
out += identity
30+
out = self.relu(out)
31+
return out
32+
33+
34+
class ResNet(base.Arch):
35+
"""
36+
PaddleScience风格的ResNet实现,支持自定义输入输出、层数、特征提取等。
37+
"""
38+
39+
def __init__(
40+
self,
41+
input_keys,
42+
output_keys,
43+
num_blocks=(2, 2, 2, 2), # ResNet18默认
44+
num_classes=1,
45+
in_channels=3,
46+
base_channels=64,
47+
**kwargs
48+
):
49+
super().__init__()
50+
self.input_keys = input_keys
51+
self.output_keys = output_keys
52+
53+
self.conv1 = nn.Conv2D(in_channels, base_channels, 7, 2, 3)
54+
self.bn1 = nn.BatchNorm2D(base_channels)
55+
self.relu = nn.ReLU()
56+
self.maxpool = nn.MaxPool2D(3, 2, 1)
57+
58+
self.layer1 = self._make_layer(base_channels, base_channels, num_blocks[0])
59+
self.layer2 = self._make_layer(
60+
base_channels, base_channels * 2, num_blocks[1], stride=2
61+
)
62+
self.layer3 = self._make_layer(
63+
base_channels * 2, base_channels * 4, num_blocks[2], stride=2
64+
)
65+
self.layer4 = self._make_layer(
66+
base_channels * 4, base_channels * 8, num_blocks[3], stride=2
67+
)
68+
69+
self.avgpool = nn.AdaptiveAvgPool2D((1, 1))
70+
self.fc = nn.Linear(base_channels * 8, num_classes)
71+
72+
def _make_layer(self, in_channels, out_channels, blocks, stride=1):
73+
layers = [ResNetBlock(in_channels, out_channels, stride)]
74+
for _ in range(1, blocks):
75+
layers.append(ResNetBlock(out_channels, out_channels))
76+
return nn.Sequential(*layers)
77+
78+
def forward(self, x):
79+
# x: dict, 取input_keys
80+
if isinstance(x, dict):
81+
x = x[self.input_keys[0]]
82+
x = self.conv1(x)
83+
x = self.bn1(x)
84+
x = self.relu(x)
85+
x = self.maxpool(x)
86+
x = self.layer1(x)
87+
x = self.layer2(x)
88+
x = self.layer3(x)
89+
x = self.layer4(x)
90+
x = self.avgpool(x)
91+
x = paddle.flatten(x, 1)
92+
x = self.fc(x)
93+
return x

0 commit comments

Comments
 (0)