Skip to content

Commit 0620470

Browse files
committed
First pass on transforms. Committing to switch branches
1 parent 67b4381 commit 0620470

File tree

2 files changed

+69
-0
lines changed

2 files changed

+69
-0
lines changed

src/torchcodec/_core/Transform.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
// Copyright (c) Meta Platforms, Inc. and affiliates.
2+
// All rights reserved.
3+
//
4+
// This source code is licensed under the BSD-style license found in the
5+
// LICENSE file in the root directory of this source tree.
6+
7+
#include "src/torchcodec/_core/Transform.h"
8+
#include <torch/types.h>
9+
10+
namespace facebook::torchcodec {
11+
12+
std::string toStringFilterGraph(Transform::InterpolationMode mode) {
13+
switch (mode) {
14+
case Transform::InterpolationMode::BILINEAR:
15+
return "BILINEAR";
16+
case Transform::InterpolationMode::BICUBIC:
17+
return "BICUBIC";
18+
case Transform::InterpolationMode::NEAREST:
19+
return "NEAREST";
20+
default:
21+
TORCH_CHECK(false, "Unknown interpolation mode: " + std::to_string(mode));
22+
}
23+
}
24+
25+
std::string Transform::getFilterGraphCpu() const {
26+
return "scale=width=" + std::to_string(width_) +
27+
":height=" + std::to_string(height_) +
28+
":sws_flags=" + toStringFilterGraph(interpolationMode_);
29+
}
30+
31+
} // namespace facebook::torchcodec

src/torchcodec/_core/Transform.h

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// Copyright (c) Meta Platforms, Inc. and affiliates.
2+
// All rights reserved.
3+
//
4+
// This source code is licensed under the BSD-style license found in the
5+
// LICENSE file in the root directory of this source tree.
6+
7+
#pragma once
8+
9+
#include <string>
10+
11+
namespace facebook::torchcodec {
12+
13+
class Transform {
14+
public:
15+
std::string getFilterGraphCpu() const = 0
16+
};
17+
18+
class ResizeTransform : public Transform {
19+
public:
20+
ResizeTransform(int width, int height)
21+
: width_(width),
22+
height_(height),
23+
interpolation_(InterpolationMode::BILINEAR) {}
24+
25+
ResizeTransform(int width, int height, InterpolationMode interpolation) =
26+
default;
27+
28+
std::string getFilterGraphCpu() const override;
29+
30+
enum class InterpolationMode { BILINEAR, BICUBIC, NEAREST };
31+
32+
private:
33+
int width_;
34+
int height_;
35+
InterpolationMode interpolation_;
36+
}
37+
38+
} // namespace facebook::torchcodec

0 commit comments

Comments
 (0)