Skip to content

Commit 45222dc

Browse files
author
Daniel Flores
committed
update tensorContext vars, use std::max
1 parent e14dd0c commit 45222dc

File tree

2 files changed

+18
-19
lines changed

2 files changed

+18
-19
lines changed

src/torchcodec/_core/AVIOTensorContext.cpp

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -18,34 +18,34 @@ constexpr int64_t MAX_TENSOR_SIZE = 320'000'000; // 320 MB
1818
int read(void* opaque, uint8_t* buf, int buf_size) {
1919
auto tensorContext = static_cast<detail::TensorContext*>(opaque);
2020
TORCH_CHECK(
21-
tensorContext->current <= tensorContext->data.numel(),
22-
"Tried to read outside of the buffer: current=",
23-
tensorContext->current,
21+
tensorContext->current_pos <= tensorContext->data.numel(),
22+
"Tried to read outside of the buffer: current_pos=",
23+
tensorContext->current_pos,
2424
", size=",
2525
tensorContext->data.numel());
2626

2727
int64_t numBytesRead = std::min(
2828
static_cast<int64_t>(buf_size),
29-
tensorContext->data.numel() - tensorContext->current);
29+
tensorContext->data.numel() - tensorContext->current_pos);
3030

3131
TORCH_CHECK(
3232
numBytesRead >= 0,
3333
"Tried to read negative bytes: numBytesRead=",
3434
numBytesRead,
3535
", size=",
3636
tensorContext->data.numel(),
37-
", current=",
38-
tensorContext->current);
37+
", current_pos=",
38+
tensorContext->current_pos);
3939

4040
if (numBytesRead == 0) {
4141
return AVERROR_EOF;
4242
}
4343

4444
std::memcpy(
4545
buf,
46-
tensorContext->data.data_ptr<uint8_t>() + tensorContext->current,
46+
tensorContext->data.data_ptr<uint8_t>() + tensorContext->current_pos,
4747
numBytesRead);
48-
tensorContext->current += numBytesRead;
48+
tensorContext->current_pos += numBytesRead;
4949
return numBytesRead;
5050
}
5151

@@ -54,7 +54,7 @@ int write(void* opaque, const uint8_t* buf, int buf_size) {
5454
auto tensorContext = static_cast<detail::TensorContext*>(opaque);
5555

5656
int64_t bufSize = static_cast<int64_t>(buf_size);
57-
if (tensorContext->current + bufSize > tensorContext->data.numel()) {
57+
if (tensorContext->current_pos + bufSize > tensorContext->data.numel()) {
5858
TORCH_CHECK(
5959
tensorContext->data.numel() * 2 <= MAX_TENSOR_SIZE,
6060
"We tried to allocate an output encoded tensor larger than ",
@@ -68,18 +68,17 @@ int write(void* opaque, const uint8_t* buf, int buf_size) {
6868
}
6969

7070
TORCH_CHECK(
71-
tensorContext->current + bufSize <= tensorContext->data.numel(),
71+
tensorContext->current_pos + bufSize <= tensorContext->data.numel(),
7272
"Re-allocation of the output tensor didn't work. ",
7373
"This should not happen, please report on TorchCodec bug tracker");
7474

7575
uint8_t* outputTensorData = tensorContext->data.data_ptr<uint8_t>();
76-
std::memcpy(outputTensorData + tensorContext->current, buf, bufSize);
77-
tensorContext->current += bufSize;
76+
std::memcpy(outputTensorData + tensorContext->current_pos, buf, bufSize);
77+
tensorContext->current_pos += bufSize;
7878
// Track the maximum position written so getOutputTensor's narrow() does not
7979
// truncate the file if final seek was backwards
80-
if (tensorContext->current > tensorContext->max) {
81-
tensorContext->max = tensorContext->current;
82-
}
80+
tensorContext->max_pos =
81+
std::max(tensorContext->current_pos, tensorContext->max_pos);
8382
return buf_size;
8483
}
8584

@@ -93,7 +92,7 @@ int64_t seek(void* opaque, int64_t offset, int whence) {
9392
ret = tensorContext->data.numel();
9493
break;
9594
case SEEK_SET:
96-
tensorContext->current = offset;
95+
tensorContext->current_pos = offset;
9796
ret = offset;
9897
break;
9998
default:
@@ -125,7 +124,7 @@ AVIOToTensorContext::AVIOToTensorContext()
125124

126125
torch::Tensor AVIOToTensorContext::getOutputTensor() {
127126
return tensorContext_.data.narrow(
128-
/*dim=*/0, /*start=*/0, /*length=*/tensorContext_.max);
127+
/*dim=*/0, /*start=*/0, /*length=*/tensorContext_.max_pos);
129128
}
130129

131130
} // namespace facebook::torchcodec

src/torchcodec/_core/AVIOTensorContext.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ namespace detail {
1515

1616
struct TensorContext {
1717
torch::Tensor data;
18-
int64_t current;
19-
int64_t max;
18+
int64_t current_pos;
19+
int64_t max_pos;
2020
};
2121

2222
} // namespace detail

0 commit comments

Comments
 (0)