Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions include/infinicore/common/hash.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,14 @@ hash_combine(size_t &seed, const T &value) {

// Specialization for Tensor
inline void hash_combine(size_t &seed, Tensor tensor) {
// For an undefined tensor (default-constructed), just mix in a sentinel
// value so that optional arguments like weight/pos_weight do not cause
// null dereferences when computing cache keys.
if (!tensor) {
hash_combine(seed, static_cast<size_t>(0));
return;
}

hash_combine(seed, static_cast<size_t>(tensor->dtype()));
for (Size shape : tensor->shape()) {
hash_combine(seed, shape);
Expand Down
5 changes: 5 additions & 0 deletions include/infinicore/ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,8 @@
#include "ops/rope.hpp"
#include "ops/silu.hpp"
#include "ops/swiglu.hpp"
#include "ops/atanh.hpp"
#include "ops/addcmul.hpp"
#include "ops/cdist.hpp"
#include "ops/reciprocal.hpp"
#include "ops/binary_cross_entropy_with_logits.hpp"
17 changes: 17 additions & 0 deletions include/infinicore/ops/addcmul.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#pragma once

#include "../device.hpp"
#include "common/op.hpp"

namespace infinicore::op {

class Addcmul {
public:
// schema: out, input, t1, t2, value
using schema = void (*)(Tensor, Tensor, Tensor, Tensor, float);
static void execute(Tensor out, Tensor input, Tensor t1, Tensor t2, float value);
static common::OpDispatcher<schema> &dispatcher();
};
Tensor addcmul(Tensor input, Tensor t1, Tensor t2, float value);
void addcmul_(Tensor out, Tensor input, Tensor t1, Tensor t2, float value);
} // namespace infinicore::op
34 changes: 34 additions & 0 deletions include/infinicore/ops/atanh.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#pragma once

#include "../device.hpp"
#include "common/op.hpp"

namespace infinicore::op {

class Atanh {
public:
// schema 定义为:void(输出 Tensor, 输入 Tensor)
using schema = void (*)(Tensor, Tensor);

// 执行函数
static void execute(Tensor y, Tensor a);

// 获取算子分发器,用于多后端(CPU/CUDA 等)匹配
static common::OpDispatcher<schema> &dispatcher();
};

/**
* @brief 计算输入 Tensor 的反双曲正切值 (out-of-place)
* @param a 输入 Tensor
* @return 包含结果的新 Tensor
*/
Tensor atanh(Tensor a);

/**
* @brief 计算输入 Tensor 的反双曲正切值 (in-place / specified output)
* @param y 输出 Tensor
* @param a 输入 Tensor
*/
void atanh_(Tensor y, Tensor a);

} // namespace infinicore::op
46 changes: 46 additions & 0 deletions include/infinicore/ops/binary_cross_entropy_with_logits.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
#pragma once

#include "../device.hpp"
#include "common/op.hpp"
#include <string>

namespace infinicore::op {

class BinaryCrossEntropyWithLogits {
public:
/**
* @brief BCEWithLogits 算子的函数原型
* 参数顺序: out, logits, target, weight, pos_weight, reduction
*/
using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor, std::string);

static void execute(Tensor out,
Tensor logits,
Tensor target,
Tensor weight,
Tensor pos_weight,
std::string reduction);

static common::OpDispatcher<schema> &dispatcher();
};

/**
* @brief 非原地操作接口 (Out-of-place)
*/
Tensor binary_cross_entropy_with_logits(Tensor logits,
Tensor target,
Tensor weight = {},
Tensor pos_weight = {},
std::string reduction = "mean");

/**
* @brief 显式指定输出张量的接口
*/
void binary_cross_entropy_with_logits_(Tensor out,
Tensor logits,
Tensor target,
Tensor weight,
Tensor pos_weight,
std::string reduction);

} // namespace infinicore::op
32 changes: 32 additions & 0 deletions include/infinicore/ops/cdist.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#pragma once

#include "../device.hpp"
#include "common/op.hpp"

namespace infinicore::op {

class Cdist {
public:
/**
* @brief 成对距离计算算子 (Pairwise distance)
* schema: out (M, N), x1 (M, D), x2 (N, D), p (norm degree)
*/
using schema = void (*)(Tensor, Tensor, Tensor, double);

static void execute(Tensor out, Tensor x1, Tensor x2, double p);

static common::OpDispatcher<schema> &dispatcher();
};

/**
* @brief 非原地(Out-of-place)接口
* @return 返回形状为 (M, N) 的新 Tensor
*/
Tensor cdist(Tensor x1, Tensor x2, double p = 2.0);

/**
* @brief 显式指定输出接口
*/
void cdist_(Tensor out, Tensor x1, Tensor x2, double p = 2.0);

} // namespace infinicore::op
16 changes: 16 additions & 0 deletions include/infinicore/ops/reciprocal.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#pragma once

#include "../device.hpp"
#include "common/op.hpp"

namespace infinicore::op {
class Reciprocal {
public:
using schema = void (*)(Tensor, Tensor);
static void execute(Tensor y, Tensor x);
static common::OpDispatcher<schema> &dispatcher();
};

Tensor reciprocal(Tensor x);
void reciprocal_(Tensor y, Tensor x);
} // namespace infinicore::op
3 changes: 3 additions & 0 deletions include/infinicore/tensor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,9 @@ class Tensor {

operator bool() const;

// 判断 Tensor 是否已定义(是否持有有效实现)
bool is_defined() const { return static_cast<bool>(*this); }

protected:
Tensor(std::shared_ptr<TensorImpl> impl) : impl_(std::move(impl)) {}
std::shared_ptr<TensorImpl> impl_;
Expand Down
5 changes: 5 additions & 0 deletions include/infiniop.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@
#include "infiniop/ops/topkrouter.h"
#include "infiniop/ops/topksoftmax.h"
#include "infiniop/ops/zeros.h"
#include "infiniop/ops/atanh.h"
#include "infiniop/ops/addcmul.h"
#include "infiniop/ops/cdist.h"
#include "infiniop/ops/binary_cross_entropy_with_logits.h"
#include "infiniop/ops/reciprocal.h"
#include "infiniop/tensor_descriptor.h"

#endif // __INFINIOP_API_H__
57 changes: 57 additions & 0 deletions include/infiniop/ops/addcmul.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
#ifndef __INFINIOP_ADDCMUL_API_H__
#define __INFINIOP_ADDCMUL_API_H__

#include "../operator_descriptor.h"

// 定义 addcmul 算子描述符类型
typedef struct InfiniopDescriptor *infiniopAddcmulDescriptor_t;

/**
* @brief 创建 Addcmul 算子描述符
* @param handle 算子句柄
* @param desc_ptr 指向返回的描述符指针
* @param out 输出张量描述符
* @param input 加项张量描述符
* @param tensor1 乘项张量1描述符
* @param tensor2 乘项张量2描述符
* @param value 乘积的标量系数
*/
__C __export infiniStatus_t infiniopCreateAddcmulDescriptor(infiniopHandle_t handle,
infiniopAddcmulDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t out,
infiniopTensorDescriptor_t input,
infiniopTensorDescriptor_t tensor1,
infiniopTensorDescriptor_t tensor2,
float value);

/**
* @brief 获取 Addcmul 计算所需的临时空间大小
*/
__C __export infiniStatus_t infiniopGetAddcmulWorkspaceSize(infiniopAddcmulDescriptor_t desc, size_t *size);

/**
* @brief 执行 Addcmul 计算
* @param desc 算子描述符
* @param workspace 临时空间指针
* @param workspace_size 临时空间大小
* @param out 输出数据指针
* @param input 加项数据指针
* @param tensor1 乘项1数据指针
* @param tensor2 乘项2数据指针
* @param stream 计算流 (CUDA stream 等)
*/
__C __export infiniStatus_t infiniopAddcmul(infiniopAddcmulDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *out,
const void *input,
const void *tensor1,
const void *tensor2,
void *stream);

/**
* @brief 销毁 Addcmul 算子描述符
*/
__C __export infiniStatus_t infiniopDestroyAddcmulDescriptor(infiniopAddcmulDescriptor_t desc);

#endif
24 changes: 24 additions & 0 deletions include/infiniop/ops/atanh.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#ifndef __INFINIOP_Atanh_API_H__
#define __INFINIOP_Atanh_API_H__

#include "../operator_descriptor.h"

typedef struct InfiniopDescriptor *infiniopAtanhDescriptor_t;

__C __export infiniStatus_t infiniopCreateAtanhDescriptor(infiniopHandle_t handle,
infiniopAtanhDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t y,
infiniopTensorDescriptor_t a);

__C __export infiniStatus_t infiniopGetAtanhWorkspaceSize(infiniopAtanhDescriptor_t desc, size_t *size);

__C __export infiniStatus_t infiniopAtanh(infiniopAtanhDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *y,
const void *a,
void *stream);

__C __export infiniStatus_t infiniopDestroyAtanhDescriptor(infiniopAtanhDescriptor_t desc);

#endif
73 changes: 73 additions & 0 deletions include/infiniop/ops/binary_cross_entropy_with_logits.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
#ifndef __INFINIOP_BINARY_CROSS_ENTROPY_WITH_LOGITS_API_H__
#define __INFINIOP_BINARY_CROSS_ENTROPY_WITH_LOGITS_API_H__

#include "../operator_descriptor.h"

// 定义归约方式枚举
typedef enum {
INFINIOP_REDUCTION_NONE = 0,
INFINIOP_REDUCTION_MEAN = 1,
INFINIOP_REDUCTION_SUM = 2
} infiniopReduction_t;

// 定义 BCEWithLogits 算子描述符类型
typedef struct InfiniopDescriptor *infiniopBCEWithLogitsDescriptor_t;

/**
* @brief 创建 BCEWithLogits 算子描述符
* @param handle 算子句柄
* @param desc_ptr 指向返回的描述符指针
* @param out 输出张量描述符 (none时与input同形状,mean/sum时为标量)
* @param logits 输入 Logits 张量描述符
* @param target 目标标签张量描述符
* @param weight 样本权重描述符 (可选,不需要则传 NULL)
* @param pos_weight 正样本权重描述符 (可选,不需要则传 NULL)
* @param reduction 归约方式 (none, mean, sum)
*/
__C __export infiniStatus_t infiniopCreateBCEWithLogitsDescriptor(
infiniopHandle_t handle,
infiniopBCEWithLogitsDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t out,
infiniopTensorDescriptor_t logits,
infiniopTensorDescriptor_t target,
infiniopTensorDescriptor_t weight,
infiniopTensorDescriptor_t pos_weight,
infiniopReduction_t reduction);

/**
* @brief 获取 BCEWithLogits 计算所需的临时空间大小
*/
__C __export infiniStatus_t infiniopGetBCEWithLogitsWorkspaceSize(
infiniopBCEWithLogitsDescriptor_t desc,
size_t *size);

/**
* @brief 执行 BCEWithLogits 计算
* @param desc 算子描述符
* @param workspace 临时空间指针
* @param workspace_size 临时空间大小
* @param out 输出数据指针
* @param logits Logits 数据指针
* @param target Target 数据指针
* @param weight 权重数据指针 (可选,传 NULL 表示权重全为 1)
* @param pos_weight 正样本权重数据指针 (可选,传 NULL 表示权重全为 1)
* @param stream 计算流
*/
__C __export infiniStatus_t infiniopBCEWithLogits(
infiniopBCEWithLogitsDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *out,
const void *logits,
const void *target,
const void *weight,
const void *pos_weight,
void *stream);

/**
* @brief 销毁 BCEWithLogits 算子描述符
*/
__C __export infiniStatus_t infiniopDestroyBCEWithLogitsDescriptor(
infiniopBCEWithLogitsDescriptor_t desc);

#endif
Loading
Loading