-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
74 lines (53 loc) · 2.08 KB
/
main.py
File metadata and controls
74 lines (53 loc) · 2.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import torch
import torchvision.transforms as transforms
import cv2
from timm.models.vision_transformer import TimeSformer
# # TimeSformerモデルのロード
# model = TimeSformer(img_size=224, num_classes=1000, num_frames=8, attention_type='divided_space_time')
# # モデルを評価モードに設定
# model.eval()
# モデルの設定
num_classes = 1000 # クラス数は事前学習済みモデルのものに合わせる
num_frames = 8 # フレーム数
# TimeSformerモデルの作成
model = TimeSformer(
img_size=224,
num_classes=num_classes,
num_frames=num_frames,
attention_type='divided_space_time',
)
# 事前学習済みの重みのパス
pretrained_weights_path = './models/TimeSformer_divST_8x32_224_K400.pyth'
# 保存済みの重みを読み込む
checkpoint = torch.load(pretrained_weights_path, map_location='cpu')
# モデルに重みをロード
model.load_state_dict(checkpoint['model'])
# モデルを評価モードに設定
model.eval()
# カメラの設定
cap = cv2.VideoCapture(0) # 0はデフォルトのカメラを指定
# フレームをリアルタイムで取得
while True:
ret, frame = cap.read()
# フレームが正常に取得されたら処理を行う
if ret:
# 画像の前処理(リサイズ、正規化など)
transform = transforms.Compose([
transforms.ToPILImage(),
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
input_frame = transform(frame).unsqueeze(0)
# 推論
with torch.no_grad():
output = model(input_frame)
# ここでoutputを使用して必要な処理を行う(例:結果の表示、何かしらのアクションの実行)
# ウィンドウにフレームを表示
cv2.imshow('Real-time Inference', frame)
# 'q'を入力するとループを終了
if cv2.waitKey(1) & 0xFF == ord('q'):
break
# キャプチャを解放
cap.release()
cv2.destroyAllWindows()