Skip to content

Commit e14dd0c

Browse files
author
Daniel Flores
committed
to_tensor, AVIOTensorContext fix
1 parent 2117716 commit e14dd0c

File tree

2 files changed

+19
-18
lines changed

2 files changed

+19
-18
lines changed

src/torchcodec/_core/AVIOTensorContext.cpp

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,34 +18,34 @@ constexpr int64_t MAX_TENSOR_SIZE = 320'000'000; // 320 MB
1818
int read(void* opaque, uint8_t* buf, int buf_size) {
1919
auto tensorContext = static_cast<detail::TensorContext*>(opaque);
2020
TORCH_CHECK(
21-
tensorContext->current_pos <= tensorContext->data.numel(),
22-
"Tried to read outside of the buffer: current_pos=",
23-
tensorContext->current_pos,
21+
tensorContext->current <= tensorContext->data.numel(),
22+
"Tried to read outside of the buffer: current=",
23+
tensorContext->current,
2424
", size=",
2525
tensorContext->data.numel());
2626

2727
int64_t numBytesRead = std::min(
2828
static_cast<int64_t>(buf_size),
29-
tensorContext->data.numel() - tensorContext->current_pos);
29+
tensorContext->data.numel() - tensorContext->current);
3030

3131
TORCH_CHECK(
3232
numBytesRead >= 0,
3333
"Tried to read negative bytes: numBytesRead=",
3434
numBytesRead,
3535
", size=",
3636
tensorContext->data.numel(),
37-
", current_pos=",
38-
tensorContext->current_pos);
37+
", current=",
38+
tensorContext->current);
3939

4040
if (numBytesRead == 0) {
4141
return AVERROR_EOF;
4242
}
4343

4444
std::memcpy(
4545
buf,
46-
tensorContext->data.data_ptr<uint8_t>() + tensorContext->current_pos,
46+
tensorContext->data.data_ptr<uint8_t>() + tensorContext->current,
4747
numBytesRead);
48-
tensorContext->current_pos += numBytesRead;
48+
tensorContext->current += numBytesRead;
4949
return numBytesRead;
5050
}
5151

@@ -54,7 +54,7 @@ int write(void* opaque, const uint8_t* buf, int buf_size) {
5454
auto tensorContext = static_cast<detail::TensorContext*>(opaque);
5555

5656
int64_t bufSize = static_cast<int64_t>(buf_size);
57-
if (tensorContext->current_pos + bufSize > tensorContext->data.numel()) {
57+
if (tensorContext->current + bufSize > tensorContext->data.numel()) {
5858
TORCH_CHECK(
5959
tensorContext->data.numel() * 2 <= MAX_TENSOR_SIZE,
6060
"We tried to allocate an output encoded tensor larger than ",
@@ -68,17 +68,18 @@ int write(void* opaque, const uint8_t* buf, int buf_size) {
6868
}
6969

7070
TORCH_CHECK(
71-
tensorContext->current_pos + bufSize <= tensorContext->data.numel(),
71+
tensorContext->current + bufSize <= tensorContext->data.numel(),
7272
"Re-allocation of the output tensor didn't work. ",
7373
"This should not happen, please report on TorchCodec bug tracker");
7474

7575
uint8_t* outputTensorData = tensorContext->data.data_ptr<uint8_t>();
76-
std::memcpy(outputTensorData + tensorContext->current_pos, buf, bufSize);
77-
tensorContext->current_pos += bufSize;
76+
std::memcpy(outputTensorData + tensorContext->current, buf, bufSize);
77+
tensorContext->current += bufSize;
7878
// Track the maximum position written so getOutputTensor's narrow() does not
7979
// truncate the file if final seek was backwards
80-
tensorContext->max_pos =
81-
std::max(tensorContext->current_pos, tensorContext->max_pos);
80+
if (tensorContext->current > tensorContext->max) {
81+
tensorContext->max = tensorContext->current;
82+
}
8283
return buf_size;
8384
}
8485

@@ -92,7 +93,7 @@ int64_t seek(void* opaque, int64_t offset, int whence) {
9293
ret = tensorContext->data.numel();
9394
break;
9495
case SEEK_SET:
95-
tensorContext->current_pos = offset;
96+
tensorContext->current = offset;
9697
ret = offset;
9798
break;
9899
default:
@@ -124,7 +125,7 @@ AVIOToTensorContext::AVIOToTensorContext()
124125

125126
torch::Tensor AVIOToTensorContext::getOutputTensor() {
126127
return tensorContext_.data.narrow(
127-
/*dim=*/0, /*start=*/0, /*length=*/tensorContext_.max_pos);
128+
/*dim=*/0, /*start=*/0, /*length=*/tensorContext_.max);
128129
}
129130

130131
} // namespace facebook::torchcodec

src/torchcodec/_core/AVIOTensorContext.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ namespace detail {
1515

1616
struct TensorContext {
1717
torch::Tensor data;
18-
int64_t current_pos;
19-
int64_t max_pos;
18+
int64_t current;
19+
int64_t max;
2020
};
2121

2222
} // namespace detail

0 commit comments

Comments
 (0)